Homework 08: Simulation-Based Inference (SBI)
Simulation-based inference (SBI) is a family of methods for Bayesian inference when we can generate data from a simulator but cannot evaluate the likelihood in closed form, or prefer not to rely on one. Instead of writing down a closed-form posterior, we simulate many synthetic datasets under different parameter settings and use those simulations to learn how observations constrain the latent parameters.
This is especially useful in physics, where forward models are often natural to simulate even when exact likelihood-based inference is difficult. In this homework, you will work through a simple radioactive decay example and compare an exact reference posterior against an amortized neural approximation.
Consider a simple radioactive decay experiment: we observe repeated waiting times between decay events and want to infer the decay rate parameter lambda. This is a natural simulation-based inference setting, because the simulator is easy to write and produces realistic stochastic measurements from a physical process.
In this homework, you will:
define a prior and a decay-time simulator
compute an exact posterior on a grid for comparison
train a small amortized posterior network
evaluate calibration on held-out simulations
generate posterior predictive samples of decay times
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
torch.manual_seed(7)
np.random.seed(7)
random.seed(7)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device:', device)
Problem 1: Prior, Decay-Time Simulator, and Summary Statistic
Consider a simple radioactive decay model where the waiting time between decay events is exponentially distributed:
latent parameter: lambda ~ Uniform(0.2, 2.0)
observations: t_i ~ Exponential(rate=lambda), i = 1, ..., n_obs
Because the expected waiting time is E[t_i] = 1/lambda, a larger lambda means shorter typical waiting times. We will simulate many experiments, each consisting of n_obs waiting times.
Implement:
sample_prior(n), returning shape (n, 1)
simulate_decay_times(lam, n_obs=20), returning shape (n, n_obs)
compute_summary(t), returning the sample mean waiting time with shape (n, 1)
Goal: build the simulator interface used in the rest of the homework.
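For reference, here is a minimal NumPy sketch of the same interface. The homework itself expects PyTorch tensors. The names `rng`, `sample_prior_np`, `simulate_decay_times_np`, and `compute_summary_np` are hypothetical and used only for illustration. The sketch also highlights a common pitfall: NumPy's `Generator.exponential` takes the scale 1/lambda, not the rate. Passing `lam` directly would silently invert the physics.

```python
import numpy as np

rng = np.random.default_rng(7)  # hypothetical seeded generator for this sketch

def sample_prior_np(n, low=0.2, high=2.0):
    """Draw n decay rates from Uniform(low, high); shape (n, 1)."""
    return rng.uniform(low, high, size=(n, 1))

def simulate_decay_times_np(lam, n_obs=20):
    """Draw n_obs exponential waiting times per row of lam; shape (n, n_obs).

    NumPy parameterizes the exponential by scale = 1 / rate, so we pass 1 / lam.
    A (n, 1) scale broadcasts against size (n, n_obs), giving each row its own rate.
    """
    lam = np.asarray(lam, dtype=float)
    return rng.exponential(scale=1.0 / lam, size=(lam.shape[0], n_obs))

def compute_summary_np(t):
    """Sample-mean waiting time per experiment; shape (n, 1)."""
    return np.asarray(t).mean(axis=1, keepdims=True)

lam = sample_prior_np(8)
t = simulate_decay_times_np(lam, n_obs=20)
s = compute_summary_np(t)
print(lam.shape, t.shape, s.shape)  # (8, 1) (8, 20) (8, 1)
```

By contrast, PyTorch's `torch.distributions.Exponential` is parameterized by the rate directly.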
def sample_prior(n):
    # YOUR CODE HERE
    raise NotImplementedError('Implement sample_prior')

def simulate_decay_times(lam, n_obs=20):
    # YOUR CODE HERE
    raise NotImplementedError('Implement simulate_decay_times')

def compute_summary(t):
    # YOUR CODE HERE
    raise NotImplementedError('Implement compute_summary')
# Local checks
lam = sample_prior(8)
t = simulate_decay_times(lam, n_obs=20)
s = compute_summary(t)
assert lam.shape == (8, 1)
assert t.shape == (8, 20)
assert s.shape == (8, 1)
assert torch.all(lam >= 0.2) and torch.all(lam <= 2.0)
assert torch.all(t >= 0.0)
print('Problem 1 checks passed.')
Figure for Problem 1
Why this plot:
Left: sampled decay rates from the prior
Right: example simulated waiting-time sequences and their sample-mean summaries
This connects the latent physics parameter to the observable data. Larger rates should tend to produce shorter waiting times.
lam_demo = sample_prior(6)
t_demo = simulate_decay_times(lam_demo, n_obs=20)
s_demo = compute_summary(t_demo)
fig, ax = plt.subplots(1, 2, figsize=(10, 3.5))
ax[0].hist(lam_demo[:, 0].detach().cpu().numpy(), bins=6, color='#4C78A8', alpha=0.85)
ax[0].set_title('Prior samples of decay rate')
ax[0].set_xlabel('lambda')
ax[0].set_ylabel('count')
for i in range(min(4, t_demo.shape[0])):
    ax[1].plot(t_demo[i].detach().cpu().numpy(), marker='o', lw=1.2, alpha=0.8,
               label='lambda=' + '{:.2f}'.format(lam_demo[i, 0].item()))
    ax[1].axhline(s_demo[i, 0].item(), ls='--', lw=1)
ax[1].set_title('Simulated decay waiting times')
ax[1].set_xlabel('event index')
ax[1].set_ylabel('waiting time')
ax[1].legend(frameon=False, fontsize=8)
plt.tight_layout()
plt.show()
Problem 2: Exact Posterior on a Grid
For exponential waiting times with rate parameter lambda, the likelihood of the observed waiting times t_obs is available in closed form as the product of the individual exponential densities. That means we can compute the exact posterior numerically on a grid and use it as a reference.
Assume a single observed dataset t_obs. Since the prior is uniform on [0.2, 2.0], the unnormalized posterior is proportional to the exponential likelihood inside that interval and zero outside.
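Written out, the standard result for n = n_obs independent exponential observations is shown below. It serves as a check on the grid computation rather than a replacement for it. Here S denotes the sum of the observed waiting times:

```latex
\begin{aligned}
\log p(t_{\text{obs}} \mid \lambda) &= \sum_{i=1}^{n} \left(\log\lambda - \lambda t_i\right) = n\log\lambda - \lambda S,
  \qquad S = \sum_{i=1}^{n} t_i, \\
p(\lambda \mid t_{\text{obs}}) &\propto \lambda^{n} e^{-\lambda S}\, \mathbf{1}\left[0.2 \le \lambda \le 2.0\right].
\end{aligned}
```

This is a Gamma(shape = n + 1, rate = S) density truncated to [0.2, 2.0]. Setting the derivative n/lambda - S to zero gives the unconstrained mode lambda = n/S = 1/(mean waiting time). When that value falls outside [0.2, 2.0], the posterior piles up against the nearer prior boundary.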
Implement:
log_likelihood_grid(lam_grid, t_obs)
posterior_from_loglik(loglik), which normalizes on the grid
posterior_mean(lam_grid, posterior)
Goal: create a numerical benchmark posterior before training a neural approximation.
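One way to do the grid computation is sketched below in NumPy. It assumes an evenly spaced grid and represents the posterior as probability masses that sum to 1, which matches the local check on `post.sum()`. Subtracting the maximum log-likelihood before exponentiating (the usual log-sum-exp shift) avoids underflow when n_obs is large. The `_np` function names are hypothetical.

```python
import numpy as np

def log_likelihood_grid_np(lam_grid, t_obs):
    """Exponential log-likelihood n*log(lam) - lam*sum(t) at each grid point."""
    lam_grid = np.asarray(lam_grid, dtype=float)
    t_obs = np.asarray(t_obs, dtype=float)
    n, S = t_obs.size, t_obs.sum()
    return n * np.log(lam_grid) - lam_grid * S

def posterior_from_loglik_np(loglik):
    """Normalize to grid probabilities; a uniform prior on the grid only adds a constant."""
    w = np.exp(loglik - loglik.max())  # shift so the largest weight is exp(0) = 1
    return w / w.sum()

def posterior_mean_np(lam_grid, posterior):
    """Posterior mean as a probability-weighted sum over grid points."""
    return float(np.sum(np.asarray(lam_grid) * posterior))

t_obs = np.array([0.2, 0.4, 0.1, 0.3, 0.5])
lam_grid = np.linspace(0.2, 2.0, 301)
post = posterior_from_loglik_np(log_likelihood_grid_np(lam_grid, t_obs))
print(post.sum(), lam_grid[post.argmax()], posterior_mean_np(lam_grid, post))
```

Because this grid spans exactly the prior support, the uniform prior drops out after normalization. For this dataset, n/S = 5/1.5 ≈ 3.33 lies above 2.0, so the grid maximum sits at the upper edge.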
def log_likelihood_grid(lam_grid, t_obs):
    # YOUR CODE HERE
    raise NotImplementedError('Implement log_likelihood_grid')

def posterior_from_loglik(loglik):
    # YOUR CODE HERE
    raise NotImplementedError('Implement posterior_from_loglik')

def posterior_mean(lam_grid, posterior):
    # YOUR CODE HERE
    raise NotImplementedError('Implement posterior_mean')
# Local checks
t_obs = torch.tensor([0.2, 0.4, 0.1, 0.3, 0.5])
lam_grid = torch.linspace(0.2, 2.0, 301)
loglik = log_likelihood_grid(lam_grid, t_obs)
post = posterior_from_loglik(loglik)
pm = posterior_mean(lam_grid, post)
assert loglik.shape == lam_grid.shape
assert post.shape == lam_grid.shape
assert torch.isclose(post.sum(), torch.tensor(1.0), atol=1e-5)
assert 0.2 <= pm.item() <= 2.0
print('Problem 2 checks passed.')
Figure for Problem 2
Why this plot:
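It shows the exact posterior over the decay rate lambda for one observed decay experiment
It also marks 1 / (observed mean waiting time), which is the maximum-likelihood estimate of lambda, and the posterior mean estimate
This is the reference posterior that your amortized inference model should learn to approximate. For this particular dataset, 1 / 0.3 ≈ 3.33 lies outside the prior support. The posterior therefore increases monotonically across [0.2, 2.0] and piles up at the upper boundary lambda = 2.0.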
t_obs_demo = torch.tensor([0.2, 0.4, 0.1, 0.3, 0.5])
lam_grid_demo = torch.linspace(0.2, 2.0, 301)
loglik_demo = log_likelihood_grid(lam_grid_demo, t_obs_demo)
post_demo = posterior_from_loglik(loglik_demo)
pm_demo = posterior_mean(lam_grid_demo, post_demo)
plt.figure(figsize=(6.5, 3.8))
plt.plot(lam_grid_demo.detach().cpu().numpy(), post_demo.detach().cpu().numpy(), lw=2)
plt.axvline(1.0 / t_obs_demo.mean().item(), color='crimson', ls='--', label='1 / mean waiting time')
plt.axvline(pm_demo.item(), color='black', ls=':', label='posterior mean')
plt.xlabel('lambda')
plt.ylabel('posterior probability (grid)')
plt.title('Exact posterior for one decay experiment')
plt.legend(frameon=False)
plt.tight_layout()
plt.show()
Problem 3: Amortized Posterior Network
In simulation-based inference, we often train a neural network that maps observations or summaries directly to an approximate posterior. This avoids solving a new inference problem from scratch for every new experiment.
In this problem, train a small network that takes the sample mean waiting time as input and predicts the parameters of a Gaussian approximation to the posterior over lambda:
posterior mean mu(summary)
posterior log standard deviation log_std(summary)
Implement:
PosteriorNet.forward(summary), returning (mu, log_std)
gaussian_nll(lam, mu, log_std)
the training loop over simulated (lambda, summary) pairs
Goal: build a minimal amortized neural posterior estimator.
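For a single training pair, the Gaussian negative log-likelihood is NLL = log sigma + 0.5 log(2 pi) + 0.5 ((lambda - mu) / sigma)^2, with sigma = exp(log_std). Predicting log_std rather than sigma keeps sigma positive without any constraint on the network output. Below is a minimal NumPy sketch; the `_np` name is hypothetical. In the homework, the same expression would use torch operations so that gradients flow.

```python
import math
import numpy as np

def gaussian_nll_np(lam, mu, log_std):
    """Mean negative log-density of lam under N(mu, exp(log_std)^2)."""
    lam, mu, log_std = (np.asarray(x, dtype=float) for x in (lam, mu, log_std))
    z = (lam - mu) * np.exp(-log_std)  # standardized residual (lam - mu) / sigma
    nll = log_std + 0.5 * math.log(2.0 * math.pi) + 0.5 * z**2
    return float(nll.mean())

# At lam == mu with sigma = 1, only the normalizing constant 0.5*log(2*pi) remains.
print(gaussian_nll_np([[1.0]], [[1.0]], [[0.0]]))
```

The log_std term stops the network from shrinking sigma without bound. A smaller sigma lowers log sigma but inflates the squared-residual term. PyTorch also provides `torch.nn.GaussianNLLLoss`, which takes the variance instead of log_std and omits the constant unless `full=True`.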
class PosteriorNet(nn.Module):
    def __init__(self, hidden_dim=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, summary):
        # YOUR CODE HERE
        raise NotImplementedError('Implement PosteriorNet.forward')

def gaussian_nll(lam, mu, log_std):
    # YOUR CODE HERE
    raise NotImplementedError('Implement gaussian_nll')

model = PosteriorNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for step in range(400):
    # YOUR CODE HERE
    # 1) sample lambda from the prior
    # 2) simulate decay times and compute summaries
    # 3) run the network and minimize Gaussian negative log likelihood
    pass
# Local checks
summary_test = torch.tensor([[0.5], [1.0]], dtype=torch.float32).to(device)
mu_test, log_std_test = model(summary_test)
assert mu_test.shape == (2, 1)
assert log_std_test.shape == (2, 1)
print('Problem 3 checks passed.')
Figure for Problem 3
Why this plot:
It shows how the network maps the observed mean waiting time to a posterior mean estimate
It also shows the predicted posterior uncertainty as a function of the summary statistic
Physically, larger mean waiting times should generally correspond to smaller decay rates.
summary_grid = torch.linspace(0.2, 3.0, 200).unsqueeze(1).to(device)
with torch.no_grad():
    mu_grid, log_std_grid = model(summary_grid)
std_grid = log_std_grid.exp().detach().cpu().numpy()
mu_grid = mu_grid.detach().cpu().numpy()
summary_grid_np = summary_grid.detach().cpu().numpy()
fig, ax = plt.subplots(1, 2, figsize=(10, 3.5))
ax[0].plot(summary_grid_np, mu_grid, lw=2, label='network posterior mean')
ax[0].plot(summary_grid_np, 1.0 / summary_grid_np, ls='--', color='gray', label='1 / sample mean')
ax[0].set_title('Predicted posterior mean vs sample mean')
ax[0].set_xlabel('summary = mean waiting time')
ax[0].set_ylabel('predicted mu for lambda')
ax[0].legend(frameon=False)
ax[1].plot(summary_grid_np, std_grid, lw=2, color='#F58518', label='network posterior std')
ax[1].set_title('Predicted posterior std vs sample mean')
ax[1].set_xlabel('summary = mean waiting time')
ax[1].set_ylabel('predicted std')
ax[1].legend(frameon=False)
plt.tight_layout()
plt.show()
Problem 4: Calibration and Coverage
A posterior approximation should provide uncertainty estimates that are not only sharp but also calibrated: a nominal 68% interval should contain the true parameter in roughly 68% of repeated experiments.
In this problem, estimate empirical coverage of nominal 68% intervals on held-out simulated decay experiments:
simulate many (lambda, t) pairs
compute summaries
use the network to predict a Gaussian posterior
build central 68% intervals and check how often the true lambda falls inside
Implement empirical_coverage(model, n_eval=2000, n_obs=20).
Goal: compare nominal uncertainty against the empirical coverage frequency.
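For a Gaussian posterior, the central 68% interval is approximately mu ± sigma. A ±1 sigma interval covers about 68.27%, and the exact 68% half-width is about 0.9945 sigma. The NumPy sketch below checks the coverage arithmetic using synthetic, hypothetical predictions in place of a trained model. When the true lambda really is drawn from N(mu, sigma^2), coverage should land near 0.68. An overconfident model that reports sigma/2 should cover only about 38%.

```python
import numpy as np
from statistics import NormalDist

def empirical_coverage_np(lam_true, mu, std, level=0.68):
    """Fraction of true values inside the central `level` interval of N(mu, std^2)."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)  # about 0.9945 for level = 0.68
    lam_true, mu, std = (np.asarray(x, dtype=float) for x in (lam_true, mu, std))
    inside = np.abs(lam_true - mu) <= z * std
    return float(inside.mean())

rng = np.random.default_rng(0)
n = 20000
mu = rng.uniform(0.2, 2.0, size=n)            # hypothetical predicted means
std = np.full(n, 0.1)                         # hypothetical predicted stds
lam_true = mu + std * rng.standard_normal(n)  # truths consistent with those predictions

print(empirical_coverage_np(lam_true, mu, std))        # calibrated: near 0.68
print(empirical_coverage_np(lam_true, mu, std / 2.0))  # overconfident: near 0.38
```

In the homework, lam_true would come from `sample_prior` and (mu, std) from the trained network applied to the simulated summaries.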
def empirical_coverage(model, n_eval=2000, n_obs=20):
    # YOUR CODE HERE
    raise NotImplementedError('Implement empirical_coverage')
coverage = empirical_coverage(model, n_eval=300, n_obs=20)
assert 0.0 <= coverage <= 1.0
print('Empirical 68% coverage:', coverage)
print('Problem 4 checks passed.')
Figure for Problem 4
Why this plot:
It compares predicted 68% posterior intervals against the true latent decay rate on held-out simulations
The title reports empirical coverage
This gives a direct visual test of whether the learned uncertainty is calibrated.
n_plot = 80
lam_eval = sample_prior(n_plot).to(device)
t_eval = simulate_decay_times(lam_eval, n_obs=20)
s_eval = compute_summary(t_eval)
with torch.no_grad():
    mu_eval, log_std_eval = model(s_eval)
    std_eval = log_std_eval.exp()
lo = (mu_eval - std_eval).detach().cpu().numpy()[:, 0]
hi = (mu_eval + std_eval).detach().cpu().numpy()[:, 0]
mu_np = mu_eval.detach().cpu().numpy()[:, 0]
lam_np = lam_eval.detach().cpu().numpy()[:, 0]
inside = (lam_np >= lo) & (lam_np <= hi)
coverage_plot = inside.mean()
order = np.argsort(mu_np)
idx = np.arange(n_plot)
plt.figure(figsize=(8.5, 4))
plt.vlines(idx, lo[order], hi[order], color=np.where(inside[order], '#54A24B', '#E45756'), alpha=0.8)
plt.scatter(idx, lam_np[order], s=12, color='black', label='true lambda')
plt.scatter(idx, mu_np[order], s=12, color='#4C78A8', label='predicted mean')
plt.title('Predicted 68% intervals on held-out decay experiments (coverage={:.3f})'.format(coverage_plot))
plt.xlabel('held-out simulation (sorted by predicted mean)')
plt.ylabel('lambda')
plt.legend(frameon=False)
plt.tight_layout()
plt.show()
Problem 5: Posterior Predictive Simulation
A posterior is scientifically useful partly because it lets us propagate parameter uncertainty forward through the simulator and ask whether the inferred parameters reproduce the observed experiment.
In this problem, implement posterior predictive simulation for one observed decay dataset:
compute the observed summary
get the Gaussian posterior approximation from the network
draw lambda samples from that approximate posterior
simulate new decay-time datasets and compare predictive summaries to the observed summary
Implement:
sample_posterior_gaussian(mu, log_std, n_samples)
posterior_predictive_summaries(model, t_obs, n_samples=200, n_obs=20)
Goal: connect posterior inference to posterior predictive checking for a physics experiment.
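A caution before implementing `sample_posterior_gaussian`: a Gaussian approximation has unbounded support, so some draws can be zero or negative. An exponential rate must be strictly positive, and `torch.distributions.Exponential` will typically raise an error for such rates when argument validation is enabled. One simple workaround is to clip draws to the prior support [0.2, 2.0]. This is an assumption, not something the homework prescribes; rejection sampling within that interval is an alternative. The NumPy sketch below uses hypothetical `_np` names and takes mu and log_std directly instead of a model, so it stays self-contained.

```python
import numpy as np

rng = np.random.default_rng(1)

def sample_posterior_gaussian_np(mu, log_std, n_samples, low=0.2, high=2.0):
    """Draw lambda ~ N(mu, exp(log_std)^2), clipped to the prior support; shape (n_samples, 1)."""
    draws = mu + np.exp(log_std) * rng.standard_normal((n_samples, 1))
    return np.clip(draws, low, high)  # assumption: clip rather than reject

def posterior_predictive_summaries_np(mu, log_std, n_samples=200, n_obs=20):
    """Mean waiting time of one simulated replicate experiment per posterior draw."""
    lam = sample_posterior_gaussian_np(mu, log_std, n_samples)
    t_rep = rng.exponential(scale=1.0 / lam, size=(n_samples, n_obs))  # scale = 1 / rate
    return t_rep.mean(axis=1, keepdims=True)

# Hypothetical network output for one observed experiment: mu = 1.0, sigma = 0.25.
pred = posterior_predictive_summaries_np(mu=1.0, log_std=np.log(0.25), n_samples=300)
print(pred.shape, float(pred.mean()))
```

Clipping piles a little probability mass exactly at the boundaries. For a posterior that sits well inside [0.2, 2.0], that effect is negligible.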
def sample_posterior_gaussian(mu, log_std, n_samples):
    # YOUR CODE HERE
    raise NotImplementedError('Implement sample_posterior_gaussian')

def posterior_predictive_summaries(model, t_obs, n_samples=200, n_obs=20):
    # YOUR CODE HERE
    raise NotImplementedError('Implement posterior_predictive_summaries')
t_obs = simulate_decay_times(torch.tensor([[1.0]]), n_obs=20)[0]
pred_summaries = posterior_predictive_summaries(model, t_obs, n_samples=100, n_obs=20)
assert pred_summaries.shape == (100, 1)
print('Problem 5 checks passed.')
Figure for Problem 5
Why this plot:
It compares the observed mean waiting time to the distribution of posterior predictive mean waiting times
It tests whether the inferred posterior places weight on decay rates that reproduce data similar to the experiment you saw
This is a standard posterior predictive diagnostic and a useful end-to-end SBI sanity check.
t_obs_demo = simulate_decay_times(torch.tensor([[1.0]]), n_obs=20)[0]
obs_summary = compute_summary(t_obs_demo.unsqueeze(0))[0, 0].item()
pred_summaries = posterior_predictive_summaries(model, t_obs_demo, n_samples=300, n_obs=20)
plt.figure(figsize=(6.5, 3.8))
plt.hist(pred_summaries[:, 0].detach().cpu().numpy(), bins=25, color='#72B7B2', alpha=0.85)
plt.axvline(obs_summary, color='black', ls='--', lw=2, label='observed mean waiting time')
plt.xlabel('posterior predictive sample mean waiting time')
plt.ylabel('count')
plt.title('Posterior predictive summaries for decay experiment')
plt.legend(frameon=False)
plt.tight_layout()
plt.show()
Acknowledgments
Initial version: Mark Neubauer
© Copyright 2026