Homework 08: Simulation-Based Inference (SBI)#

Simulation-based inference (SBI) is a family of methods for Bayesian inference when we can generate data from a simulator but do not want to rely on an analytically tractable likelihood. Instead of writing down a closed-form posterior, we simulate many synthetic datasets under different parameter settings and use those simulations to learn how observations constrain latent parameters.

This is especially useful in physics, where forward models are often natural to simulate even when exact likelihood-based inference is difficult. In this homework, you will work through a simple radioactive decay example and compare an exact reference posterior against an amortized neural approximation.

In a simple radioactive decay experiment, we observe repeated waiting times between decay events and want to infer the decay rate parameter lambda. This is a natural simulation-based inference setting: the simulator is easy to write, and it produces realistic stochastic measurements from a physics process.

In this homework, you will:

  • define a prior and a decay-time simulator

  • compute an exact posterior on a grid for comparison

  • train a small amortized posterior network

  • evaluate calibration on held-out simulations

  • generate posterior predictive samples of decay times

import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(7)
np.random.seed(7)
random.seed(7)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device:', device)

Problem 1: Prior, Decay-Time Simulator, and Summary Statistic#

Consider a simple radioactive decay model where the waiting time between decay events is exponentially distributed:

  • latent parameter: lambda ~ Uniform(0.2, 2.0)

  • observations: t_i ~ Exponential(rate=lambda)

Physically, larger lambda means shorter typical waiting times. We will simulate many experiments, each consisting of n_obs waiting times.

Implement:

  • sample_prior(n) returning shape (n, 1)

  • simulate_decay_times(lam, n_obs=20) returning shape (n, n_obs)

  • compute_summary(t) returning the sample mean waiting time with shape (n, 1)

Goal: build the simulator interface used in the rest of the homework.
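As a shape reference, here is one possible sketch of these three functions, using inverse-CDF sampling for the exponential waiting times. The `_sketch` suffixes are illustrative names only, not part of the required interface.

```python
import torch

def sample_prior_sketch(n):
    # Uniform(0.2, 2.0) draws of the decay rate, shape (n, 1)
    return 0.2 + 1.8 * torch.rand(n, 1)

def simulate_decay_times_sketch(lam, n_obs=20):
    # Inverse-CDF sampling: t = -log(1 - u) / lam gives Exponential(rate=lam);
    # using 1 - u (with u in [0, 1)) avoids log(0)
    u = torch.rand(lam.shape[0], n_obs)
    return -torch.log(1.0 - u) / lam  # lam broadcasts from (n, 1) to (n, n_obs)

def compute_summary_sketch(t):
    # Sample mean waiting time per simulated experiment, shape (n, 1)
    return t.mean(dim=1, keepdim=True)
```

An equivalent alternative is `torch.distributions.Exponential(rate=lam).sample((n_obs,))` followed by a transpose; the inverse-CDF form keeps the batch dimension explicit.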

def sample_prior(n):
    # YOUR CODE HERE
    raise NotImplementedError('Implement sample_prior')


def simulate_decay_times(lam, n_obs=20):
    # YOUR CODE HERE
    raise NotImplementedError('Implement simulate_decay_times')


def compute_summary(t):
    # YOUR CODE HERE
    raise NotImplementedError('Implement compute_summary')


# Local checks
lam = sample_prior(8)
t = simulate_decay_times(lam, n_obs=20)
s = compute_summary(t)

assert lam.shape == (8, 1)
assert t.shape == (8, 20)
assert s.shape == (8, 1)
assert torch.all(lam >= 0.2) and torch.all(lam <= 2.0)
assert torch.all(t >= 0.0)
print('Problem 1 checks passed.')

Figure for Problem 1#

Why this plot:

  • Left: sampled decay rates from the prior

  • Right: example simulated waiting-time sequences and their sample-mean summaries

This connects the latent physics parameter to the observable data. Larger rates should tend to produce shorter waiting times.

lam_demo = sample_prior(6)
t_demo = simulate_decay_times(lam_demo, n_obs=20)
s_demo = compute_summary(t_demo)

fig, ax = plt.subplots(1, 2, figsize=(10, 3.5))
ax[0].hist(lam_demo[:, 0].detach().cpu().numpy(), bins=6, color='#4C78A8', alpha=0.85)
ax[0].set_title('Prior samples of decay rate')
ax[0].set_xlabel('lambda')
ax[0].set_ylabel('count')

for i in range(min(4, t_demo.shape[0])):
    ax[1].plot(t_demo[i].detach().cpu().numpy(), marker='o', lw=1.2, alpha=0.8,
               label='lambda=' + '{:.2f}'.format(lam_demo[i, 0].item()))
    ax[1].axhline(s_demo[i, 0].item(), ls='--', lw=1)
ax[1].set_title('Simulated decay waiting times')
ax[1].set_xlabel('event index')
ax[1].set_ylabel('waiting time')
ax[1].legend(frameon=False, fontsize=8)
plt.tight_layout()
plt.show()

Problem 2: Exact Posterior on a Grid#

For exponential waiting times, the likelihood of observed waiting times t_obs has a simple closed form as a function of the rate lambda. That means we can compute the exact posterior numerically on a grid and use it as a reference.

Assume a single observed dataset t_obs. Since the prior is uniform on [0.2, 2.0], the unnormalized posterior is proportional to the exponential likelihood inside that interval and zero outside.

Implement:

  • log_likelihood_grid(lam_grid, t_obs)

  • posterior_from_loglik(loglik) that normalizes on the grid

  • posterior_mean(lam_grid, posterior)

Goal: create a numerical benchmark posterior before training a neural approximation.
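For n iid observations t_i ~ Exponential(rate=lambda), the log-likelihood is n log(lambda) - lambda * sum(t_i). One possible sketch of the grid computation, with a max-subtraction for numerical stability (the `_sketch` names are illustrative only):

```python
import torch

def log_likelihood_grid_sketch(lam_grid, t_obs):
    # log L(lam) = n * log(lam) - lam * sum(t_i) for exponential waiting times
    n = t_obs.numel()
    return n * torch.log(lam_grid) - lam_grid * t_obs.sum()

def posterior_from_loglik_sketch(loglik):
    # Subtract the max before exponentiating for numerical stability,
    # then normalize so the grid values sum to one
    w = torch.exp(loglik - loglik.max())
    return w / w.sum()

def posterior_mean_sketch(lam_grid, posterior):
    # Expected decay rate under the normalized grid posterior
    return (lam_grid * posterior).sum()
```

Because the prior is uniform on the grid's support, multiplying by the prior only changes the normalization constant, so normalizing the exponentiated log-likelihood already yields the posterior on the grid.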

def log_likelihood_grid(lam_grid, t_obs):
    # YOUR CODE HERE
    raise NotImplementedError('Implement log_likelihood_grid')


def posterior_from_loglik(loglik):
    # YOUR CODE HERE
    raise NotImplementedError('Implement posterior_from_loglik')


def posterior_mean(lam_grid, posterior):
    # YOUR CODE HERE
    raise NotImplementedError('Implement posterior_mean')


# Local checks
t_obs = torch.tensor([0.2, 0.4, 0.1, 0.3, 0.5])
lam_grid = torch.linspace(0.2, 2.0, 301)
loglik = log_likelihood_grid(lam_grid, t_obs)
post = posterior_from_loglik(loglik)
pm = posterior_mean(lam_grid, post)

assert loglik.shape == lam_grid.shape
assert post.shape == lam_grid.shape
assert torch.isclose(post.sum(), torch.tensor(1.0), atol=1e-5)
assert 0.2 <= pm.item() <= 2.0
print('Problem 2 checks passed.')

Figure for Problem 2#

Why this plot:

  • It shows the exact posterior over the decay rate lambda for one observed decay experiment

  • It also marks the observed mean waiting time and the posterior mean estimate

This is the reference posterior that your amortized inference model should learn to approximate.

t_obs_demo = torch.tensor([0.2, 0.4, 0.1, 0.3, 0.5])
lam_grid_demo = torch.linspace(0.2, 2.0, 301)
loglik_demo = log_likelihood_grid(lam_grid_demo, t_obs_demo)
post_demo = posterior_from_loglik(loglik_demo)
pm_demo = posterior_mean(lam_grid_demo, post_demo)

plt.figure(figsize=(6.5, 3.8))
plt.plot(lam_grid_demo.detach().cpu().numpy(), post_demo.detach().cpu().numpy(), lw=2)
plt.axvline(1.0 / t_obs_demo.mean().item(), color='crimson', ls='--', label='1 / mean waiting time')
plt.axvline(pm_demo.item(), color='black', ls=':', label='posterior mean')
plt.xlabel('lambda')
plt.ylabel('posterior probability (grid)')
plt.title('Exact posterior for one decay experiment')
plt.legend(frameon=False)
plt.tight_layout()
plt.show()

Problem 3: Amortized Posterior Network#

In simulation-based inference, we often train a neural network that maps observations or summaries directly to an approximate posterior. This avoids solving a new inference problem from scratch for every new experiment.

In this problem, train a small network that takes the sample mean waiting time as input and predicts the parameters of a Gaussian approximation to the posterior over lambda:

  • posterior mean mu(summary)

  • posterior log standard deviation log_sigma(summary)

Implement:

  • PosteriorNet.forward(summary) returning (mu, log_std)

  • gaussian_nll(lam, mu, log_std)

  • the training loop over simulated (lambda, summary) pairs

Goal: build a minimal amortized neural posterior estimator.
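The two pieces of machinery here are splitting the network's two outputs into (mu, log_std) and evaluating the Gaussian negative log likelihood. One possible sketch (the `_sketch` names are illustrative; inside `PosteriorNet.forward` the split would apply to `self.net(summary)`):

```python
import math
import torch

def split_mu_log_std_sketch(out):
    # Split a (n, 2) network output into mu and log_std, each (n, 1)
    mu, log_std = out.chunk(2, dim=-1)
    return mu, log_std

def gaussian_nll_sketch(lam, mu, log_std):
    # Mean negative log density of lam under N(mu, exp(log_std)^2):
    # 0.5 * ((lam - mu) / sigma)^2 + log(sigma) + 0.5 * log(2 * pi)
    inv_var = torch.exp(-2.0 * log_std)
    nll = 0.5 * inv_var * (lam - mu) ** 2 + log_std + 0.5 * math.log(2.0 * math.pi)
    return nll.mean()
```

Each training step then draws a fresh batch of lambda values from the prior, simulates decay times, computes summaries, and takes a gradient step on this NLL, so the network is amortized over the whole prior range rather than fit to a single experiment.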

class PosteriorNet(nn.Module):
    def __init__(self, hidden_dim=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, summary):
        # YOUR CODE HERE
        raise NotImplementedError('Implement PosteriorNet.forward')


def gaussian_nll(lam, mu, log_std):
    # YOUR CODE HERE
    raise NotImplementedError('Implement gaussian_nll')


model = PosteriorNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(400):
    # YOUR CODE HERE
    # 1) sample lambda from the prior
    # 2) simulate decay times and compute summaries
    # 3) run the network and minimize Gaussian negative log likelihood
    pass


# Local checks
summary_test = torch.tensor([[0.5], [1.0]], dtype=torch.float32).to(device)
mu_test, log_std_test = model(summary_test)

assert mu_test.shape == (2, 1)
assert log_std_test.shape == (2, 1)
print('Problem 3 checks passed.')

Figure for Problem 3#

Why this plot:

  • It shows how the network maps the observed mean waiting time to a posterior mean estimate

  • It also shows the predicted posterior uncertainty as a function of the summary statistic

Physically, larger mean waiting times should generally correspond to smaller decay rates.

summary_grid = torch.linspace(0.2, 3.0, 200).unsqueeze(1).to(device)
with torch.no_grad():
    mu_grid, log_std_grid = model(summary_grid)
std_grid = log_std_grid.exp().detach().cpu().numpy()
mu_grid = mu_grid.detach().cpu().numpy()
summary_grid_np = summary_grid.detach().cpu().numpy()

fig, ax = plt.subplots(1, 2, figsize=(10, 3.5))
ax[0].plot(summary_grid_np, mu_grid, lw=2, label='network posterior mean')
ax[0].plot(summary_grid_np, 1.0 / summary_grid_np, ls='--', color='gray', label='1 / sample mean')
ax[0].set_title('Predicted posterior mean vs sample mean')
ax[0].set_xlabel('summary = mean waiting time')
ax[0].set_ylabel('predicted mu for lambda')
ax[0].legend(frameon=False)

ax[1].plot(summary_grid_np, std_grid, lw=2, color='#F58518', label='network posterior std')
ax[1].set_title('Predicted posterior std vs sample mean')
ax[1].set_xlabel('summary = mean waiting time')
ax[1].set_ylabel('predicted std')
ax[1].legend(frameon=False)

plt.tight_layout()
plt.show()

Problem 4: Calibration and Coverage#

A posterior approximation should provide uncertainty estimates that are not only sharp but also realistic.

In this problem, estimate empirical coverage of nominal 68% intervals on held-out simulated decay experiments:

  • simulate many (lambda, t) pairs

  • compute summaries

  • use the network to predict a Gaussian posterior

  • build central 68% intervals (for a Gaussian approximation, mu plus or minus one predicted standard deviation) and check how often the true lambda falls inside

Implement empirical_coverage(model, n_eval=2000, n_obs=20).

Goal: compare nominal uncertainty against observed frequency coverage.
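The logic above can be sketched as follows. To keep the sketch self-contained it inlines the Uniform(0.2, 2.0) prior and exponential simulator, and takes a generic `predict_fn` in place of the homework's `model` argument; the `_sketch` name and signature are illustrative only.

```python
import torch

def empirical_coverage_sketch(predict_fn, n_eval=2000, n_obs=20):
    # Fraction of held-out simulations whose true rate lands in mu +/- sigma;
    # predict_fn maps a (n, 1) summary tensor to (mu, log_std)
    lam = 0.2 + 1.8 * torch.rand(n_eval, 1)                # Uniform(0.2, 2.0) prior
    t = -torch.log(1.0 - torch.rand(n_eval, n_obs)) / lam  # exponential waiting times
    s = t.mean(dim=1, keepdim=True)                        # sample-mean summary
    with torch.no_grad():
        mu, log_std = predict_fn(s)
    std = log_std.exp()
    # the central 68% interval of a Gaussian is mu +/- one standard deviation
    inside = (lam >= mu - std) & (lam <= mu + std)
    return inside.float().mean().item()
```

A well-calibrated model should return a value near 0.68; values far below indicate overconfident intervals, values far above indicate underconfident ones.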

def empirical_coverage(model, n_eval=2000, n_obs=20):
    # YOUR CODE HERE
    raise NotImplementedError('Implement empirical_coverage')


coverage = empirical_coverage(model, n_eval=300, n_obs=20)
assert 0.0 <= coverage <= 1.0
print('Empirical 68% coverage:', coverage)
print('Problem 4 checks passed.')

Figure for Problem 4#

Why this plot:

  • It compares predicted 68% posterior intervals against the true latent decay rate on held-out simulations

  • The title reports empirical coverage

This gives a direct visual test of whether the learned uncertainty is calibrated.

n_plot = 80
lam_eval = sample_prior(n_plot).to(device)
t_eval = simulate_decay_times(lam_eval, n_obs=20)
s_eval = compute_summary(t_eval)
with torch.no_grad():
    mu_eval, log_std_eval = model(s_eval)
std_eval = log_std_eval.exp()
lo = (mu_eval - std_eval).detach().cpu().numpy()[:, 0]
hi = (mu_eval + std_eval).detach().cpu().numpy()[:, 0]
mu_np = mu_eval.detach().cpu().numpy()[:, 0]
lam_np = lam_eval.detach().cpu().numpy()[:, 0]
inside = (lam_np >= lo) & (lam_np <= hi)
coverage_plot = inside.mean()

order = np.argsort(mu_np)
idx = np.arange(n_plot)
plt.figure(figsize=(8.5, 4))
plt.vlines(idx, lo[order], hi[order], color=np.where(inside[order], '#54A24B', '#E45756'), alpha=0.8)
plt.scatter(idx, lam_np[order], s=12, color='black', label='true lambda')
plt.scatter(idx, mu_np[order], s=12, color='#4C78A8', label='predicted mean')
plt.title('Predicted 68% intervals on held-out decay experiments (coverage={:.3f})'.format(coverage_plot))
plt.xlabel('held-out simulation (sorted by predicted mean)')
plt.ylabel('lambda')
plt.legend(frameon=False)
plt.tight_layout()
plt.show()

Problem 5: Posterior Predictive Simulation#

A posterior is scientifically useful because it lets us propagate uncertainty back through the simulator and ask whether inferred parameters reproduce the observed experiment.

In this problem, implement posterior predictive simulation for one observed decay dataset:

  • compute the observed summary

  • get the Gaussian posterior approximation from the network

  • draw lambda samples from that approximate posterior

  • simulate new decay-time datasets and compare predictive summaries to the observed summary

Implement:

  • sample_posterior_gaussian(mu, log_std, n_samples)

  • posterior_predictive_summaries(model, t_obs, n_samples=200, n_obs=20)

Goal: connect posterior inference to posterior predictive checking for a physics experiment.
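One possible sketch of the predictive pipeline, using reparameterized Gaussian draws and the same inverse-CDF exponential simulator as before. The `_sketch` names are illustrative, and clamping the lambda samples to the prior support [0.2, 2.0] is an assumption (the Gaussian approximation can place mass outside it):

```python
import torch

def sample_posterior_gaussian_sketch(mu, log_std, n_samples):
    # Reparameterized draws from N(mu, exp(log_std)^2), shape (n_samples, 1)
    return mu + log_std.exp() * torch.randn(n_samples, 1)

def posterior_predictive_summaries_sketch(model, t_obs, n_samples=200, n_obs=20):
    s_obs = t_obs.mean().view(1, 1)                  # observed summary statistic
    with torch.no_grad():
        mu, log_std = model(s_obs)
    lam = sample_posterior_gaussian_sketch(mu, log_std, n_samples)
    lam = lam.clamp(0.2, 2.0)                        # keep rates in the prior support
    # simulate a fresh dataset per posterior draw, then summarize each one
    t_new = -torch.log(1.0 - torch.rand(n_samples, n_obs)) / lam
    return t_new.mean(dim=1, keepdim=True)           # (n_samples, 1)
```

If the posterior is reasonable, the observed summary should fall well inside the spread of these predictive summaries rather than in a far tail.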

def sample_posterior_gaussian(mu, log_std, n_samples):
    # YOUR CODE HERE
    raise NotImplementedError('Implement sample_posterior_gaussian')


def posterior_predictive_summaries(model, t_obs, n_samples=200, n_obs=20):
    # YOUR CODE HERE
    raise NotImplementedError('Implement posterior_predictive_summaries')


t_obs = simulate_decay_times(torch.tensor([[1.0]]), n_obs=20)[0]
pred_summaries = posterior_predictive_summaries(model, t_obs, n_samples=100, n_obs=20)

assert pred_summaries.shape == (100, 1)
print('Problem 5 checks passed.')

Figure for Problem 5#

Why this plot:

  • It compares the observed mean waiting time to the distribution of posterior predictive mean waiting times

  • It tests whether the inferred posterior places weight on decay rates that reproduce data similar to the experiment you saw

This is a standard posterior predictive diagnostic and a useful end-to-end SBI sanity check.

t_obs_demo = simulate_decay_times(torch.tensor([[1.0]]), n_obs=20)[0]
obs_summary = compute_summary(t_obs_demo.unsqueeze(0))[0, 0].item()
pred_summaries = posterior_predictive_summaries(model, t_obs_demo, n_samples=300, n_obs=20)

plt.figure(figsize=(6.5, 3.8))
plt.hist(pred_summaries[:, 0].detach().cpu().numpy(), bins=25, color='#72B7B2', alpha=0.85)
plt.axvline(obs_summary, color='black', ls='--', lw=2, label='observed mean waiting time')
plt.xlabel('posterior predictive sample mean waiting time')
plt.ylabel('count')
plt.title('Posterior predictive summaries for decay experiment')
plt.legend(frameon=False)
plt.tight_layout()
plt.show()

Acknowledgments#

  • Initial version: Mark Neubauer

© Copyright 2026