
๐ŸŽผ Fugue User Guide

Fugue Logo

A production-ready, monadic probabilistic programming library for Rust

Write elegant probabilistic programs by composing Model values in direct style; execute them with pluggable interpreters and state-of-the-art inference algorithms.


Supported Rust: 1.70+ โ€ข Platforms: Linux / macOS / Windows โ€ข Crate: fugue-ppl on crates.io


๐Ÿ‘‹ Welcome

Check out these resources to get started:

About Fugue

  • ๐Ÿงฉ Monadic PPL: Compose probabilistic programs using pure functional abstractions
  • ๐Ÿ”’ Type-Safe Distributions: 10+ built-in probability distributions with natural return types
  • ๐Ÿ“Š Multiple Inference Methods: MCMC, SMC, Variational Inference, ABC
  • ๐Ÿ” Comprehensive Diagnostics: R-hat convergence, effective sample size, validation
  • ๐Ÿš€ Production Ready: Numerically stable algorithms with memory optimization
  • โœจ Ergonomic Macros: Do-notation (prob!), vectorization (plate!), addressing (addr!)

Installation

[dependencies]
fugue-ppl = "0.1.0"

๐Ÿ” More Resources

Getting Started with Fugue

Welcome to Fugue, a type-safe probabilistic programming library for Rust! This guide will get you building Bayesian models in just 15-20 minutes.

Note

What You'll Learn

By the end of this section, you'll understand:

  • How to install and set up Fugue
  • Core concepts of probabilistic programming
  • How Fugue's type system prevents common errors
  • How to run basic Bayesian inference

Time Investment: ~15-20 minutes total

Learning Path

We recommend following this path for the best learning experience:

flowchart LR
    A[Installation<br/>2 min] --> B[Your First Model<br/>5 min]
    B --> C[Understanding Models<br/>8 min]
    C --> D[Running Inference<br/>5 min]
    D --> E[Complete Tutorials<br/>45-60 min each]

Quick Start

If you're impatient and want to see Fugue in action immediately:

# Create a new project
cargo new my_bayesian_project
cd my_bayesian_project

# Add Fugue
cargo add fugue-ppl rand

# Copy our "Hello, Probabilistic World!" example into src/main.rs
# (See Installation section)

# Run it!
cargo run

What Makes Fugue Different?

๐Ÿ”’ Type Safety First

Unlike other probabilistic programming libraries, Fugue preserves natural types:

// In Fugue โœ…
let coin: bool = sample(addr!("coin"), Bernoulli::new(0.5).unwrap());
let count: u64 = sample(addr!("events"), Poisson::new(3.0).unwrap());
let category: usize = sample(addr!("choice"), Categorical::uniform(5).unwrap());

// Other PPLs โŒ
let coin: f64 = sample("coin", Bernoulli(0.5));  // Returns 0.0 or 1.0
let count: f64 = sample("events", Poisson(3.0)); // Need to cast to int
// let category: f64 = sample("choice", Categorical([...])); // Risky indexing

๐Ÿš€ Zero-Cost Abstractions

Models compile to efficient code with no runtime overhead.

๐Ÿงฐ Composable Architecture

Separate model specification from execution strategy through handlers.

๐Ÿ“Š Production Ready

Built-in diagnostics, memory optimization, and error handling.

Architecture Overview

Fugue's modular design separates concerns cleanly:

graph TB
    subgraph "Your Code"
        M[Model Definition]
        D[Data & Observations]
    end

    subgraph "Core System"
        C[Distributions & Types]
        H[Handlers & Interpreters]
        T[Traces & Memory]
    end

    subgraph "Inference Engines"
        MCMC[MCMC Sampling]
        SMC[Particle Filtering]
        VI[Variational Inference]
        ABC[ABC Methods]
    end

    M --> C
    D --> C
    C --> H
    H --> T
    T --> MCMC
    T --> SMC
    T --> VI
    T --> ABC

The Big Picture

Probabilistic Programming lets you:

  1. Model uncertainty and relationships in data
  2. Condition on observations to learn parameters
  3. Infer posterior distributions and make predictions
  4. Quantify uncertainty in your conclusions

Fugue makes this safe, fast, and composable in Rust.

Next Steps

Ready to dive in?

Tip

Start Here!

Begin with Installation to get Fugue running on your system.

Already have Rust installed? Skip ahead to Your First Model to start building probabilistic programs right away!

After completing Getting Started, explore:


Prerequisites: Basic Rust knowledge (variables, functions, cargo commands)

Installation

Getting Fugue set up in your Rust project takes just 2 minutes. Let's get you running!

Note

Prerequisites

Fugue requires Rust 1.70+. If you don't have Rust installed:

# Install Rust via rustup
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

# Update to latest stable
rustup update stable

Adding Fugue to Your Project

New Project

cargo new my_probabilistic_project
cd my_probabilistic_project

Add Fugue to your Cargo.toml:

[dependencies]
fugue-ppl = "0.1.0"
rand = "0.8"  # For random number generation

Existing Project

Add Fugue to your existing Cargo.toml:

[dependencies]
fugue-ppl = "0.1.0"
rand = "0.8"

Or use cargo add:

cargo add fugue-ppl rand

Verification: "Hello, Probabilistic World!"

Let's verify your installation with a simple example that showcases Fugue's type safety.

Create or replace src/main.rs:

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn main() {
    println!("๐ŸŽฒ Hello, Probabilistic World!");

    // Create a simple model: flip a biased coin
    let coin_model = sample(addr!("coin"), Bernoulli::new(0.7).unwrap());

    // Run the model with a seeded RNG for reproducible results
    let mut rng = StdRng::seed_from_u64(42);
    let (is_heads, trace) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        coin_model,
    );

    // Print the result - notice it's a bool, not a float!
    let result = if is_heads { "Heads" } else { "Tails" };
    println!("๐Ÿช™ Coin flip result: {}", result);
    println!("๐Ÿ“Š Log probability: {:.4}", trace.total_log_weight());

    // Demonstrate type safety with different distributions
    let mut rng = StdRng::seed_from_u64(123);

    // Count events - returns u64 directly
    let (event_count, _) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        sample(addr!("events"), Poisson::new(3.5).unwrap()),
    );
    println!("๐ŸŽฏ Event count: {} (type: u64)", event_count);

    // Choose category - returns usize for safe indexing
    let options = vec!["red", "green", "blue"];
    let (category_idx, _) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        sample(addr!("color"), Categorical::uniform(3).unwrap()),
    );
    println!("๐ŸŽจ Chosen color: {} (safe indexing!)", options[category_idx]);

    println!("โœ… Fugue is working correctly!");
}

Run it to verify everything works:

cargo run

You should see output like:

๐ŸŽฒ Hello, Probabilistic World!
๐Ÿช™ Coin flip result: Heads
๐Ÿ“Š Log probability: -0.3567
๐ŸŽฏ Event count: 4 (type: u64)
๐ŸŽจ Chosen color: blue (safe indexing!)
โœ… Fugue is working correctly!

Tip

Type Safety in Action!

Notice how each distribution returns its natural type:

  • Bernoulli โ†’ bool (not f64)
  • Poisson โ†’ u64 (not f64)
  • Categorical โ†’ usize (not f64)

This prevents entire classes of runtime errors!

IDE Setup

VS Code

Install the rust-analyzer extension for the best development experience:

  1. Open VS Code
  2. Go to Extensions (Ctrl+Shift+X)
  3. Search for "rust-analyzer"
  4. Install the official rust-analyzer extension

Other IDEs

  • IntelliJ/CLion: Install the Rust plugin
  • Vim/Neovim: Use coc-rust-analyzer or native LSP
  • Emacs: Use lsp-mode with rust-analyzer

Optional: Running Examples

Fugue comes with comprehensive examples to explore:

# Clone the repository to access examples
git clone https://github.com/your-org/fugue-ppl
cd fugue-ppl

# List available examples
ls examples/

# Run a simple example
cargo run --example gaussian_mean -- --obs 2.5 --seed 42

# Try a more complex one
cargo run --example working_with_distributions

Troubleshooting

Common Issues

Build fails with dependency errors:

# Make sure you're using Rust 1.70+
rustc --version

# Update your dependencies
cargo update

Examples don't run:

# Make sure you're in the project root directory
pwd

# Check example names
ls examples/

IDE doesn't provide completions:

  • Make sure rust-analyzer is installed and running
  • Try restarting your IDE after installing dependencies
  • Check that your Cargo.toml has the correct dependencies

Getting Help

If you encounter issues:

  1. Check the GitHub Issues
  2. Review the examples for working code
  3. Read the API documentation

Next Steps

Installation complete! ๐ŸŽ‰

Ready to build your first probabilistic model? โ†’ Your First Model

Want to explore examples first? โ†’ Complete Tutorials


Time: ~2 minutes โ€ข Next: Your First Model

Your First Model

Now that Fugue is installed, let's build your first probabilistic model step by step. We'll start simple and gradually introduce the key concepts.

Note

Learning Goals

In 5 minutes, you'll understand:

  • How to create deterministic and probabilistic models
  • The role of addresses in probabilistic programming
  • How to condition models on observed data
  • Fugue's type-safe approach to distributions

Time: ~5 minutes

Step 1: The Simplest Model

Let's start with the simplest possible model - one that always returns the same value:

use fugue::*;

fn constant_model() -> Model<f64> {
    pure(42.0)
}

This model always returns 42.0. The pure function creates a deterministic Model<f64>.

Key insight: Models are descriptions of computations, not the computations themselves.
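To make that concrete, here is a tiny sketch: building a model value does nothing on its own, and no random number is drawn until a handler runs it (covered in Step 3).

use fugue::*;

fn main() {
    // Constructing the Model is purely a description; no sampling happens here.
    let _described: Model<f64> = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
    // A value is only produced once a handler executes the model (see Step 3).
}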

Step 2: Adding Randomness

Now let's add some randomness by sampling from a probability distribution:

use fugue::*;

fn random_model() -> Model<f64> {
    sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
}

Note

New Concepts

  • sample() - Draw a random value from a distribution
  • addr!("x") - Give this random choice a unique name/address
  • Normal::new(0.0, 1.0).unwrap() - Standard normal distribution (mean=0, std=1)
  • .unwrap() - Fugue uses safe constructors that validate parameters
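For example, constructing a distribution with an invalid parameter fails at construction time rather than producing bad samples later (a small sketch; the full validation rules appear in the Working with Distributions guide):

use fugue::*;

fn main() {
    // A negative standard deviation is rejected before any model is built.
    match Normal::new(0.0, -1.0) {
        Ok(_) => println!("unexpected: invalid parameters were accepted"),
        Err(e) => println!("rejected invalid Normal parameters: {:?}", e),
    }
}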

Step 3: Running Your Model

To actually get values from your model, you need to "run" it with a handler:

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn main() {
    let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
    
    // Create a seeded random number generator
    let mut rng = StdRng::seed_from_u64(42);
    
    // Run the model with PriorHandler (forward sampling)
    let (value, trace) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        model,
    );
    
    println!("Sampled value: {:.4}", value);
    println!("Log probability: {:.4}", trace.total_log_weight());
}

Running this outputs:

Sampled value: 1.0175
Log probability: -0.9189

Tip

Understanding the Output

  • value - The random sample from our distribution
  • trace - Records what happened during execution (choices made, probabilities)
  • log_probability - How likely this particular execution was
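The choices recorded in the trace can also be read back by address; the get_f64 accessor used below is the same one later chapters use to pull posterior samples out of MCMC traces (a small self-contained sketch):

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn main() {
    let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
    let mut rng = StdRng::seed_from_u64(42);
    let (_, trace) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        model,
    );

    // Look up the value that was recorded at address "x".
    if let Some(x) = trace.get_f64(&addr!("x")) {
        println!("Trace recorded x = {:.4}", x);
    }
}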

Step 4: Type Safety in Action

Fugue's type safety really shines with discrete distributions:

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn type_safe_examples() {
    let mut rng = StdRng::seed_from_u64(42);
    
    // Flip a coin - returns bool directly!
    let (is_heads, _) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        sample(addr!("coin"), Bernoulli::new(0.6).unwrap()),
    );
    
    // Natural boolean usage - no comparisons needed!
    let outcome = if is_heads { "Heads" } else { "Tails" };
    println!("Coin flip: {}", outcome);
    
    // Count events - returns u64 directly!
    let (count, _) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        sample(addr!("events"), Poisson::new(3.0).unwrap()),
    );
    
    println!("Event count: {} (no casting needed!)", count);
    
    // Choose category - returns usize for safe indexing!
    let colors = vec!["red", "green", "blue", "yellow"];
    let (idx, _) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        sample(addr!("color"), Categorical::uniform(4).unwrap()),
    );
    
    println!("Chosen color: {}", colors[idx]); // Safe indexing!
}

Warning

Contrast with Other PPLs

In most probabilistic programming languages:

# Other PPLs - everything returns float
coin = sample("coin", Bernoulli(0.6))  # Returns 0.0 or 1.0 
if coin == 1.0:  # Need comparison โŒ
    ...

count = sample("events", Poisson(3.0))  # Returns float
count_int = int(count)  # Need casting โŒ

idx = sample("color", Categorical([...]))  # Returns float  
colors[int(idx)]  # Risky casting and indexing โŒ

Fugue prevents these errors at compile time! โœ…

Step 5: Your First Bayesian Model

Now let's create a simple Bayesian model that learns from data:

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn estimate_mean(observation: f64) -> Model<f64> {
    sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap())  // Prior belief
        .bind(move |mu| {
            observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), observation)  // Likelihood
                .map(move |_| mu)  // Return the parameter
        })
}

fn main() {
    let observation = 3.0;  // We observed a value of 3.0
    let model = estimate_mean(observation);
    
    let mut rng = StdRng::seed_from_u64(42);
    let (estimated_mu, trace) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        model,
    );
    
    println!("Observation: {}", observation);
    println!("Estimated mean: {:.4}", estimated_mu);
    println!("Log probability: {:.4}", trace.total_log_weight());
}

Note

What Just Happened?

This is a complete Bayesian inference setup:

  1. Prior: sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap())

    • Our initial belief about the mean (uncertain, centered at 0)
  2. Likelihood: observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), observation)

    • How likely our observation is, given different values of mu
  3. Bind: .bind(move |mu| ...)

    • Use the sampled mu in the rest of the model
  4. Return: .map(move |_| mu)

    • Return the parameter we want to estimate

Understanding Model Composition

Fugue models compose using two key operations:

map - Transform Values

let doubled = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
    .map(|x| x * 2.0);  // Apply function to the result

bind - Dependent Computations

let dependent = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
    .bind(|x| sample(addr!("y"), Normal::new(x, 0.5).unwrap()));  // y depends on x

Tip

Mental Model

  • map = "transform the output"
  • bind = "use the output in the next step"

These are the fundamental building blocks for complex probabilistic models!

Key Takeaways

After working through these examples, you should understand:

โœ… Models are values: Model<T> represents a probabilistic computation
โœ… Safe constructors: Distributions use .new().unwrap() for parameter validation
โœ… Type safety: Distributions return natural types (bool, u64, f64)
โœ… Addressing: addr!("name") gives names to random variables
โœ… Execution: Models need handlers to run and produce values
โœ… Composition: Use map and bind to build complex models from simple parts

What's Next?

You can now build and run basic probabilistic models! ๐ŸŽ‰

Continue your journey:

Tip

Next Steps

Ready for complete projects?


Time: ~5 minutes โ€ข Next: Understanding Models

Understanding Models

Now that you can build basic models, let's understand the key concepts that make Fugue powerful. This will give you the mental framework to build sophisticated probabilistic programs.

Note

Learning Goals

In 8 minutes, you'll understand:

  • Why models are separate from execution
  • How addressing enables advanced inference
  • The monadic structure and composition patterns
  • When to use map vs bind vs pure

Time: ~8 minutes

The Big Picture: Models vs Execution

One of Fugue's key insights is separating model specification from execution:

graph LR
    subgraph "Your Code"
        M[Model Definition<br/>What to compute]
    end
    
    subgraph "Runtime System"  
        H[Handler<br/>How to execute]
        T[Trace<br/>What happened]
    end
    
    M --> H
    H --> T

Why this matters: The same model can be executed in different ways:

  • Forward sampling (generate data from priors)
  • Conditioning (inference given observations)
  • Replay (MCMC proposals)
  • Scoring (compute probabilities)
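For instance, the same model-building closure can be handed to a forward-sampling handler or to an MCMC routine without changing the model itself; the sketch below uses PriorHandler and the adaptive_mcmc_chain function introduced later in this guide:

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn main() {
    // One model definition, built fresh by a closure each time it is needed.
    let model_fn = || {
        sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
            .bind(|mu| observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 2.0)
                .map(move |_| mu))
    };

    let mut rng = StdRng::seed_from_u64(42);

    // Execution strategy 1: forward sampling with PriorHandler.
    let (prior_mu, _trace) = runtime::handler::run(
        runtime::interpreters::PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        model_fn(),
    );
    println!("Prior draw of mu: {:.3}", prior_mu);

    // Execution strategy 2: MCMC over the same model definition.
    let samples = inference::mh::adaptive_mcmc_chain(&mut rng, model_fn, 200, 100);
    println!("Collected {} posterior samples", samples.len());
}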

Addresses: The Key to Advanced Inference

Every sample and observe site needs a unique address:

use fugue::*;

// Good addressing โœ…
let model = sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
    .bind(|mu| sample(addr!("sigma"), LogNormal::new(0.0, 0.5).unwrap())
        .bind(move |sigma| {
            // Chain the observations with bind/map so they stay part of the model
            observe(addr!("y1"), Normal::new(mu, sigma).unwrap(), 2.1)
                .bind(move |_| observe(addr!("y2"), Normal::new(mu, sigma).unwrap(), 1.9))
                .map(move |_| (mu, sigma))
        }));

Note

Why Addresses Matter

Addresses enable advanced inference by allowing algorithms to:

  • Identify which random choices to modify (MCMC)
  • Replay specific execution paths
  • Condition on subsets of variables
  • Debug model behavior by inspecting traces

Without addresses, you can only do forward sampling!

Addressing Patterns

Simple names for scalar parameters:

let mu = sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap());

Indexed addresses for collections:

// Shown for the address pattern; in a full model, combine per-index samples
// with plate! or traverse_vec (later in this chapter) rather than a bare loop.
for i in 0..10 {
    let x_i = sample(addr!("x", i), Normal::new(mu, 1.0).unwrap());
}

Scoped addresses for hierarchical models:

let encoder_z = sample(scoped_addr!("encoder", "z"), Normal::new(0.0, 1.0).unwrap());
let decoder_z = sample(scoped_addr!("decoder", "z"), Normal::new(0.0, 1.0).unwrap());

Warning

Address Anti-Patterns

// โŒ DON'T: Random or non-deterministic addresses  
let addr = format!("param_{}", rng.gen::<u32>());  // NEVER!

// โŒ DON'T: Reuse addresses for different purposes
sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
sample(addr!("x"), Bernoulli::new(0.5).unwrap());  // Collision!

// โŒ DON'T: Missing addresses in loops
for i in 0..10 {
    sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());  // All same address!
}

Model Composition: The Monadic Structure

Fugue models follow a monadic pattern that makes complex models composable:

Three Fundamental Operations

graph TB
    subgraph "Model Building Blocks"
        P[pure: A โ†’ ModelโŸจAโŸฉ<br/>Lift values into models]
        M[map: ModelโŸจAโŸฉ โ†’ โŸจA โ†’ BโŸฉ โ†’ ModelโŸจBโŸฉ<br/>Transform outputs]  
        B[bind: ModelโŸจAโŸฉ โ†’ โŸจA โ†’ ModelโŸจBโŸฉโŸฉ โ†’ ModelโŸจBโŸฉ<br/>Chain dependent computations]
    end

pure - Lift Values

Use pure to inject deterministic values into the probabilistic context:

use fugue::*;

// Lift a constant
let constant = pure(42.0);

// Lift computed values (here assuming a slice `data: &[f64]` is in scope)
let computed = pure(data.iter().sum::<f64>() / data.len() as f64);

map - Transform Outputs

Use map when you want to transform the result without adding randomness:

use fugue::*;

// Transform a single sample
let squared = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
    .map(|x| x * x);

// Combine multiple values  
let sum = zip(
    sample(addr!("a"), Normal::new(0.0, 1.0).unwrap()),
    sample(addr!("b"), Normal::new(0.0, 1.0).unwrap())
).map(|(a, b)| a + b);

bind - Chain Dependent Computations

Use bind when the next random choice depends on a previous one:

use fugue::*;

// Dependent sampling: variance depends on mean
let hierarchical = sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
    .bind(|mu| sample(addr!("sigma"), LogNormal::new(mu.abs().ln(), 0.1).unwrap())
        .bind(move |sigma| sample(addr!("y"), Normal::new(mu, sigma).unwrap())));

// Conditional branching: choice affects distribution  
let mixture = sample(addr!("component"), Bernoulli::new(0.5).unwrap())
    .bind(|component| {
        if component {
            sample(addr!("value"), Normal::new(-2.0, 1.0).unwrap())
        } else {
            sample(addr!("value"), Normal::new(2.0, 1.0).unwrap())  
        }
    });

Tip

When to Use What?

  • pure - Inject constants or computed values
  • map - Transform outputs, no new randomness
  • bind - Next step depends on previous random result

Rule of thumb: Use the least powerful operation that works!

Advanced Composition Patterns

Building Collections

use fugue::*;

// Fixed-size collection
let samples = traverse_vec((0..10).collect(), |i| {
    sample(addr!("x", i), Normal::new(0.0, 1.0).unwrap())
});

// Data-driven collection
let observations = traverse_vec(data, |datum| {
    observe(addr!("y", datum.id), Normal::new(datum.mu, 1.0).unwrap(), datum.value)
});

Conditional Models

use fugue::*;

fn model_selection(data: &[f64]) -> Model<String> {
    sample(addr!("use_robust"), Bernoulli::new(0.2).unwrap())
        .bind(|use_robust| {
            if use_robust {
                // Robust model with t-distribution
                sample(addr!("df"), LogNormal::new(1.0, 0.5).unwrap())
                    .bind(|df| {
                        // Hypothetical t-distribution sampling
                        pure("robust".to_string())
                    })
            } else {
                // Standard normal model
                sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
                    .map(|_| "normal".to_string())
            }
        })
}

Hierarchical Structure

use fugue::*;

fn hierarchical_model(groups: Vec<Vec<f64>>) -> Model<Vec<f64>> {
    // Global hyperparameters
    sample(addr!("global_mu"), Normal::new(0.0, 2.0).unwrap())
        .bind(|global_mu| {
            sample(addr!("global_sigma"), LogNormal::new(0.0, 0.5).unwrap())
                .bind(move |global_sigma| {
                    // Group-level parameters  
                    traverse_vec(groups.into_iter().enumerate().collect(), move |(g, data)| {
                        sample(addr!("group_mu", g), Normal::new(global_mu, global_sigma).unwrap())
                            .bind(move |group_mu| {
                                // Individual observations
                                traverse_vec(data.into_iter().enumerate().collect(), move |(i, y)| {
                                    observe(addr!("y", g, i), Normal::new(group_mu, 1.0).unwrap(), y)
                                }).map(move |_| group_mu)
                            })
                    })
                })
        })
}

Mental Models for Success

Think in Terms of Generative Stories

Ask yourself: "How could this data have been generated?"

// Story: "Each person has a skill level, and their performance 
//         on each task reflects that skill plus task-specific noise"

let model = sample(addr!("skill"), Normal::new(100.0, 15.0).unwrap())  // Person's skill
    .bind(|skill| {
        traverse_vec(tasks, move |task| {
            let difficulty = task.difficulty;
            let expected_score = skill - difficulty;
            observe(addr!("score", task.id), Normal::new(expected_score, 5.0).unwrap(), task.actual_score)
        }).map(move |_| skill)
    });

Separate What from How

  • What: Model describes relationships and distributions
  • How: Handler determines execution strategy (sampling, inference, etc.)

// The SAME model...
let model = build_regression_model(&data);

// Can be executed different ways:
let (sample, _) = run(PriorHandler { /*...*/ }, model.clone());      // Forward sampling
let (_, scored_trace) = run(ScoreGivenTrace { /*...*/ }, model.clone()); // Compute likelihood  
let mcmc_chain = adaptive_mcmc_chain(rng, || model.clone(), 1000, 500); // MCMC inference

Key Takeaways

You now understand the foundational concepts:

โœ… Separation of Concerns: Models describe computations, handlers execute them
โœ… Addressing Strategy: Unique, stable addresses enable advanced inference
โœ… Monadic Composition: pure, map, bind build complex models from simple parts
โœ… Compositional Patterns: Collections, conditionals, and hierarchical structures
โœ… Generative Thinking: Model the data generation process

What's Next?

You have the conceptual foundation to build sophisticated models! ๐ŸŽ‰

Tip

Next Steps

Continue Getting Started:

Ready for Real Projects:


Time: ~8 minutes โ€ข Next: Running Inference

Running Inference

You now know how to build probabilistic models. But models alone don't give you answers - you need inference to extract insights from them. Let's explore Fugue's inference algorithms!

Note

Learning Goals

In 5 minutes, you'll understand:

  • What inference is and why you need it
  • Fugue's main inference algorithms (MCMC, SMC, VI, ABC)
  • When to use each algorithm
  • How to run inference and interpret results

Time: ~5 minutes

What is Inference?

Inference is the process of learning about model parameters after seeing data. In Bayesian terms:

graph LR
    subgraph "Before Data"
        P["Prior Beliefs<br/>p(theta)"]
    end

    subgraph "Observing Data"
        L["Likelihood<br/>p(y|theta)"]
        D["Data<br/>yโ‚, yโ‚‚, ..."]
    end

    subgraph "After Data"
        Post["Posterior Beliefs<br/>p(theta|y)"]
    end

    P --> Post
    L --> Post
    D --> Post

The Challenge

Most real models don't have analytical solutions. We need algorithms to approximate the posterior distribution.

Fugue's Inference Arsenal

1. MCMC (Markov Chain Monte Carlo) ๐Ÿฅ‡

Best for: Most general-purpose Bayesian inference

How it works: Generates samples that approximate the posterior distribution

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn coin_bias_model(heads: u64, total: u64) -> Model<f64> {
    sample(addr!("bias"), Beta::new(1.0, 1.0).unwrap())  // Prior
        .bind(move |bias| {
            observe(addr!("heads"), Binomial::new(total, bias).unwrap(), heads)  // Likelihood
                .map(move |_| bias)
        })
}

fn main() {
    let mut rng = StdRng::seed_from_u64(42);

    // Run adaptive MCMC
    let samples = inference::mh::adaptive_mcmc_chain(
        &mut rng,
        || coin_bias_model(7, 10),  // 7 heads out of 10 flips
        1000,  // number of samples
        500,   // warmup samples
    );

    // Extract bias estimates
    let bias_samples: Vec<f64> = samples.iter()
        .filter_map(|(_, trace)| trace.get_f64(&addr!("bias")))
        .collect();

    let mean_bias = bias_samples.iter().sum::<f64>() / bias_samples.len() as f64;
    println!("Estimated bias: {:.3}", mean_bias);

    // Compute 90% credible interval
    let mut sorted = bias_samples.clone();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let lower = sorted[(0.05 * sorted.len() as f64) as usize];
    let upper = sorted[(0.95 * sorted.len() as f64) as usize];
    println!("90% credible interval: [{:.3}, {:.3}]", lower, upper);
}

When to use MCMC:

  • โœ… Want exact posterior samples
  • โœ… Moderate number of parameters (< 100)
  • โœ… Can afford computation time
  • โœ… Model evaluation is reasonably fast

2. SMC (Sequential Monte Carlo) ๐ŸŽฏ

Best for: Sequential data and online learning

How it works: Uses particles to approximate the posterior, good for streaming data

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn main() {
    let mut rng = StdRng::seed_from_u64(42);

    // Generate particles from prior
    let particles = inference::smc::smc_prior_particles(
        &mut rng,
        1000,  // number of particles
        || coin_bias_model(7, 10),
    );

    println!("Generated {} particles", particles.len());

    // Compute weighted posterior mean
    let total_weight: f64 = particles.iter().map(|p| p.weight).sum();
    let weighted_mean: f64 = particles.iter()
        .filter_map(|p| {
            p.trace.get_f64(&addr!("bias"))
                .map(|bias| bias * p.weight)
        })
        .sum::<f64>() / total_weight;

    println!("Weighted posterior mean: {:.3}", weighted_mean);

    // Check effective sample size
    let weights: Vec<f64> = particles.iter().map(|p| p.weight).collect();
    let ess = 1.0 / weights.iter().map(|w| w * w).sum::<f64>();
    println!("Effective sample size: {:.1}", ess);
}

When to use SMC:

  • โœ… Sequential/streaming data
  • โœ… Online inference needed
  • โœ… Many discrete latent variables
  • โœ… Want to visualize inference process

3. Variational Inference (VI) โšก

Best for: Fast approximate inference with many parameters

How it works: Finds the best approximation within a family of simple distributions

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn main() {
    let mut rng = StdRng::seed_from_u64(42);

    // Estimate ELBO (Evidence Lower BOund)
    let elbo = inference::vi::estimate_elbo(
        &mut rng,
        || coin_bias_model(7, 10),
        100,  // number of samples for estimation
    );

    println!("ELBO estimate: {:.3}", elbo);

    // For more sophisticated VI, you'd set up a variational guide
    // and optimize it (see the VI tutorial for details)
}

When to use VI:

  • โœ… Need fast approximate inference
  • โœ… Many parameters (> 100)
  • โœ… Can accept approximation error
  • โœ… Want predictable runtime

4. ABC (Approximate Bayesian Computation) ๐ŸŽฒ

Best for: Models where likelihood is intractable or expensive

How it works: Simulation-based inference using distance between simulated and observed data

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn main() {
    let mut rng = StdRng::seed_from_u64(42);

    // ABC with summary statistics
    let observed_summary = 0.7; // 7/10 = 0.7 success rate
    let samples = inference::abc::abc_scalar_summary(
        &mut rng,
        || sample(addr!("bias"), Beta::new(1.0, 1.0).unwrap()), // Prior only
        |trace| trace.get_f64(&addr!("bias")).unwrap_or(0.0), // Extract bias
        observed_summary,  // Target summary statistic
        0.1,              // Tolerance
        1000,             // Max samples to try
    );

    println!("ABC accepted {} samples", samples.len());

    if !samples.is_empty() {
        let abc_estimates: Vec<f64> = samples.iter()
            .filter_map(|trace| trace.get_f64(&addr!("bias")))
            .collect();
        let abc_mean = abc_estimates.iter().sum::<f64>() / abc_estimates.len() as f64;
        println!("ABC estimated bias: {:.3}", abc_mean);
    }
}

When to use ABC:

  • โœ… Likelihood is intractable or very expensive
  • โœ… Can simulate from the model easily
  • โœ… Have good summary statistics
  • โœ… Can tolerate approximation error

Algorithm Comparison

| Method | Speed | Accuracy | Use Case |
|--------|-------|----------|----------|
| MCMC | 🐌 Slow | 🎯 Exact | General-purpose, exact inference |
| SMC | 🏃 Medium | 🎯 Good | Sequential data, online learning |
| VI | 🚀 Fast | ⚠️ Approximate | Large models, fast approximate inference |
| ABC | 🐌 Slow | ⚠️ Approximate | Intractable likelihoods |

Practical Inference Workflow

Here's a typical workflow for real inference:

use fugue::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn inference_workflow() {
    let mut rng = StdRng::seed_from_u64(42);

    // 1. Define your model
    let model = || coin_bias_model(17, 25);  // 17 heads out of 25 flips

    // 2. Run inference (adaptive MCMC is often a good default)
    let samples = inference::mh::adaptive_mcmc_chain(
        &mut rng,
        model,
        2000,  // samples
        1000,  // warmup
    );

    // 3. Extract parameter values
    let bias_samples: Vec<f64> = samples.iter()
        .filter_map(|(_, trace)| trace.get_f64(&addr!("bias")))
        .collect();

    // 4. Compute summary statistics
    let mean = bias_samples.iter().sum::<f64>() / bias_samples.len() as f64;
    let variance = bias_samples.iter()
        .map(|&x| (x - mean).powi(2))
        .sum::<f64>() / (bias_samples.len() - 1) as f64;
    let std_dev = variance.sqrt();

    println!("Posterior Summary:");
    println!("  Mean: {:.3}", mean);
    println!("  Std Dev: {:.3}", std_dev);

    // 5. Check convergence (effective sample size)
    let ess = inference::diagnostics::effective_sample_size(&bias_samples);
    println!("  Effective Sample Size: {:.1}", ess);

    if ess > 100.0 {
        println!("  โœ… Good mixing!");
    } else {
        println!("  โš ๏ธ Poor mixing - consider more samples");
    }

    // 6. Make predictions
    println!("\nPredictions:");
    println!("  P(bias > 0.5) = {:.2}",
        bias_samples.iter().filter(|&&b| b > 0.5).count() as f64 / bias_samples.len() as f64);
}

Choosing the Right Algorithm

Decision Tree

graph TD
    A[Need inference?] -->|Yes| B[Real-time/online?]
    A -->|No| Z[Use PriorHandler<br/>for forward sampling]

    B -->|Yes| SMC[SMC]
    B -->|No| C[Likelihood tractable?]

    C -->|No| ABC[ABC]
    C -->|Yes| D[Many parameters?]

    D -->|Yes > 100| VI[Variational Inference]
    D -->|No < 100| E[Need exact samples?]

    E -->|Yes| MCMC[MCMC]
    E -->|No| VI2[VI for speed]

Rules of Thumb

  1. Start with MCMC for most problems - it's the most general
  2. Use SMC if you have sequential/streaming data
  3. Use VI if you need speed and can accept approximation
  4. Use ABC only when likelihood is truly intractable

Key Takeaways

You now know how to extract insights from your models:

โœ… Inference Purpose: Learn parameters from data using Bayesian updating
โœ… Algorithm Options: MCMC, SMC, VI, ABC each have their strengths
โœ… Practical Workflow: Define model โ†’ Run inference โ†’ Extract parameters โ†’ Check diagnostics
โœ… Algorithm Selection: Choose based on problem characteristics and requirements

What's Next?

You've completed Getting Started! ๐ŸŽ‰

Tip

Ready for Real Applications?

Complete Tutorials - End-to-end projects with real-world applications:

How-To Guides - Specific techniques and best practices:


Time: ~5 minutes โ€ข Next: Complete Tutorials

How-To Guides

The How-To guides provide practical, task-oriented instructions for accomplishing specific goals with Fugue. Unlike tutorials that teach concepts step-by-step, these guides assume you understand the basics and want to solve particular problems efficiently.

Guide Overview

These guides are designed to be example-first and immediately actionable. Each guide includes comprehensive, executable code examples that serve as the canonical source of truth for the patterns they demonstrate.

๐Ÿ“Š Working with Distributions

When to use: You need to understand Fugue's type-safe distribution system, parameter validation, or probability calculations.

What you'll learn:

  • Type-safe distribution usage (bool, u64, usize, f64 return types)
  • Parameter validation and error handling
  • Continuous vs discrete distribution patterns
  • Categorical distributions and safe indexing
  • Distribution composition and practical modeling
  • Probability calculations and testing strategies

Key patterns: Natural return types, parameter validation, distribution testing


๐Ÿ—๏ธ Building Complex Models

When to use: You want to compose sophisticated probabilistic models using Fugue's macro system and advanced patterns.

What you'll learn:

  • prob! macro for do-notation style probabilistic programming
  • plate! macro for vectorized operations and array processing
  • scoped_addr! macro for hierarchical address management
  • Model composition and sequential dependencies
  • Hierarchical modeling patterns
  • Bayesian linear regression and mixture models

Key patterns: Monadic composition, vectorization, hierarchical structure


โšก Optimizing Performance

When to use: Your models need to run efficiently in production or handle large-scale inference workloads.

What you'll learn:

  • Memory pooling with TracePool and PooledPriorHandler
  • Numerical stability with log-space computations
  • Efficient trace construction with TraceBuilder
  • Copy-on-write traces for MCMC optimization
  • Batch processing patterns
  • Performance monitoring and measurement

Key patterns: Memory optimization, numerical stability, batch processing


๐Ÿ” Debugging Models

When to use: Your probabilistic models aren't behaving as expected, or you need to diagnose issues in inference.

What you'll learn:

  • Comprehensive trace inspection and analysis
  • Type-safe value access with proper error handling
  • Model validation against analytical solutions
  • Safe vs strict handler usage for error resilience
  • MCMC diagnostics and convergence assessment
  • Performance and memory debugging techniques

Key patterns: Trace analysis, validation testing, diagnostic metrics


๐ŸŽ›๏ธ Custom Handlers

When to use: You need to extend Fugue's execution model with custom behavior, logging, or specialized inference algorithms.

What you'll learn:

  • Complete Handler trait implementation
  • Decorator pattern for cross-cutting concerns
  • Stateful handlers for analytics and monitoring
  • Conditional filtering and value modification
  • Performance monitoring integration
  • Custom inference algorithms (MCMC-like patterns)
  • Handler composition and chaining

Key patterns: Algebraic effects, decorator composition, custom inference


๐Ÿš€ Production Deployment

When to use: You're deploying probabilistic models to production environments and need reliability, monitoring, and operational excellence.

What you'll learn:

  • Error handling and graceful degradation patterns
  • Circuit breaker implementation for fault tolerance
  • Configuration management for multiple environments
  • Comprehensive metrics and Prometheus integration
  • Automated health checks and system validation
  • Input validation and security best practices
  • Deployment strategies (blue-green, canary, rolling)

Key patterns: Fault tolerance, observability, operational readiness


How to Use These Guides

๐ŸŽฏ Task-Oriented Approach

Each guide focuses on solving specific problems:

  • Need to understand distributions? โ†’ Start with "Working with Distributions"
  • Building complex models? โ†’ "Building Complex Models" has the macros and patterns
  • Performance issues? โ†’ "Optimizing Performance" covers memory and numerical techniques
  • Models not working? โ†’ "Debugging Models" provides diagnostic approaches
  • Need custom behavior? โ†’ "Custom Handlers" shows how to extend the system
  • Going to production? โ†’ "Production Deployment" covers operational concerns

๐Ÿ“š Progressive Complexity

The guides are ordered by increasing complexity:

graph TD
    A["๐Ÿ“Š Working with<br/>Distributions"] --> B["๐Ÿ—๏ธ Building Complex<br/>Models"]
    B --> C["โšก Optimizing<br/>Performance"]
    A --> D["๐Ÿš€ Production<br/>Deployment"]
    E["๐Ÿ” Debugging<br/>Models"] --> F["๐ŸŽ›๏ธ Custom<br/>Handlers"]
    F --> D
    C --> D

  • Start with distributions and model building
  • Add performance optimization when needed
  • Use debugging when things go wrong
  • Extend with custom handlers for specialized needs
  • Deploy with production patterns for real applications

๐Ÿ”„ Cross-References

Guides frequently reference each other:

  • Performance optimization builds on complex models
  • Debugging techniques apply to all model types
  • Custom handlers can incorporate performance patterns
  • Production deployment uses patterns from all previous guides

๐Ÿ“ Example-First Philosophy

Every guide follows the same structure:

  1. Executable examples as the source of truth
  2. Comprehensive code snippets with anchor tags
  3. Practical explanations of when and how to use patterns
  4. Testing strategies to verify correctness
  5. Best practices learned from real-world usage

Code Examples

All code examples in these guides are:

  • โœ… Executable: Run with cargo run --example <guide_name>
  • โœ… Tested: Verified with cargo test --examples
  • โœ… Documented: Included via {{#include}} from example files
  • โœ… Comprehensive: Cover real-world usage patterns
  • โœ… Type-Safe: Leverage Rust's type system throughout

Quick Reference

| Task | Guide | Key Patterns |
|------|-------|--------------|
| Understand distributions | Working with Distributions | Type safety, validation |
| Build complex models | Building Complex Models | Macros, composition |
| Optimize performance | Optimizing Performance | Memory pooling, numerics |
| Debug model issues | Debugging Models | Trace analysis, diagnostics |
| Extend functionality | Custom Handlers | Handler patterns, decorators |
| Deploy to production | Production Deployment | Fault tolerance, monitoring |

Integration with Other Documentation

๐Ÿš€ Getting Started โ†’ How-To Guides

After completing the Getting Started tutorials, use these guides to solve specific problems in your own projects.

๐Ÿ“– How-To Guides โ†’ Tutorials

For deeper conceptual understanding, see the Tutorials section which provides comprehensive examples of complete applications.

๐Ÿ”ง How-To Guides โ†’ API Reference

For detailed API documentation, consult the API Reference section for specific functions and types mentioned in these guides.

Contributing

When adding new How-To guides:

  1. Create executable examples first in examples/
  2. Use anchor tags to mark code sections for inclusion
  3. Write the guide using {{#include}} for all code snippets
  4. Test thoroughly with both cargo test and mdbook test
  5. Update this README with the new guide information

Each guide should solve specific, practical problems that users commonly encounter when working with Fugue in real applications.

Working with Distributions

Fugue's type-safe distribution system represents a principled approach to probabilistic programming, eliminating entire classes of runtime errors through rigorous type theory while preserving the full expressiveness of statistical modeling. This guide demonstrates the mathematical foundations and practical applications of Fugue's distribution architecture.

Type Theory Foundation

Fugue's distribution system is grounded in dependent type theory: each distribution is parameterized not just by its parameters θ, but also by its support type T, so sampling can only ever produce a value of type T. This eliminates the need for runtime type checking or unsafe casting operations.

Type Safety in Practice

Traditional probabilistic programming libraries return f64 for everything, leading to casting overhead and runtime errors. Fugue distributions return their natural types:

    // Demonstrate natural return types
    let coin = Bernoulli::new(0.5).unwrap();
    let flip: bool = coin.sample(&mut rng); // Natural boolean

    if flip {
        println!("๐Ÿช™ Coin flip result: Heads!");
    } else {
        println!("๐Ÿช™ Coin flip result: Tails!");
    }

No casting, no comparisons with floating-point valuesโ€”just natural boolean logic.

Continuous Distributions

Continuous distributions in Fugue model phenomena over uncountable domains, such as the real line or an interval of it. The probability density function p(x) satisfies the normalization condition ∫ p(x) dx = 1 over the support.

For computational stability, Fugue operates in log-space by default, computing log p(x) to avoid numerical underflow:

Log-Space Computation

Working directly with densities can cause severe numerical issues when p(x) is vanishingly small. Fugue's log_prob() method computes log p(x), which remains numerically stable even for extreme tail probabilities.

    // Working with continuous distributions
    let standard_normal = Normal::new(0.0, 1.0).unwrap();
    let sample: f64 = standard_normal.sample(&mut rng);
    println!("๐Ÿ“Š Standard normal sample: {:.3}", sample);

    // Compute log-probability density
    let log_density = standard_normal.log_prob(&0.0); // Peak of standard normal
    println!(
        "๐Ÿ“ˆ Log-density at x=0: {:.3} (peak of standard normal)",
        log_density
    );

    // Custom parameters
    let measurement_model = Normal::new(10.0, 0.5).unwrap();
    let measurement = measurement_model.sample(&mut rng);
    println!("๐Ÿ”ฌ Sensor measurement (ฮผ=10.0, ฯƒ=0.5): {:.3}", measurement);

Key Points:

  • sample() returns f64 for direct arithmetic
  • log_prob() computes log-density (avoids numerical underflow)
  • Parameter validation happens at construction time

Tip

Always work with log-probabilities for numerical stability. Only convert to regular probabilities when necessary for interpretation.

Discrete Distributions

Discrete distributions operate over countable support sets (such as the non-negative integers) or finite sets. The probability mass function P(x) satisfies the normalization condition Σₓ P(x) = 1 over the support.

Fugue enforces this constraint at construction time and leverages natural integer types to eliminate precision loss from floating-point representation:

Integer Precision Preservation

Unlike floating-point representations that can introduce rounding errors, Fugue's native u64 and usize types preserve exact integer values. This is crucial for count data, where even large counts must remain precisely representable.

    // Working with discrete distributions

    // Count data
    let event_rate = Poisson::new(3.0).unwrap();
    let count: u64 = event_rate.sample(&mut rng);
    println!("๐Ÿ“… Event count (ฮป=3.0): {} events", count);

    // Log-probability mass
    let prob_3_events = event_rate.log_prob(&3);
    println!(
        "๐ŸŽฏ Log-probability of exactly 3 events: {:.3}",
        prob_3_events
    );

    // Use counts directly in calculations
    let total_cost = count * 50; // Direct arithmetic with u64
    println!("๐Ÿ’ฐ Total cost ({} events ร— $50): ${}", count, total_cost);

Benefits:

  • u64 counts support direct arithmetic without casting
  • No precision loss from floating-point representations
  • Natural integration with Rust's type system

Safe Categorical Sampling

Categorical distributions return usize for safe array indexing:

    // Safe categorical sampling
    let choices = vec![0.3, 0.5, 0.2]; // Three categories
    let categorical = Categorical::new(choices).unwrap();
    let selected: usize = categorical.sample(&mut rng);

    // Safe array indexing (no bounds checking needed)
    let options = ["Option A", "Option B", "Option C"];
    println!(
        "๐ŸŽฒ Categorical choice (weights: 0.3, 0.5, 0.2): {}",
        options[selected]
    );

    // Uniform categorical
    let uniform_choice = Categorical::uniform(5).unwrap();
    let idx: usize = uniform_choice.sample(&mut rng);
    println!("๐ŸŽฏ Uniform random index (0-4): {}", idx);

Note

The usize return type eliminates bounds checking errorsโ€”the sampled index is guaranteed to be valid for the probability vector length.

Parameter Validation

Fugue enforces mathematical constraints through compile-time and runtime validation. Each distribution family has a parameter space Θ defining its valid configurations:

graph TD
    A[Parameter Input ฮธ] --> B{ฮธ โˆˆ ฮ˜?}
    B -->|Yes| C[Distribution Construction]
    B -->|No| D[ValidationError]
    C --> E[Type-Safe Sampling]
    D --> F[Early Failure Detection]

Constraint Examples:

  • Normal Distribution: requires σ > 0
  • Beta Distribution: requires α > 0 and β > 0
  • Categorical Distribution: requires pᵢ ≥ 0 for every category and Σᵢ pᵢ = 1

    // Distribution parameter validation

    // This will return an error
    match Normal::new(0.0, -1.0) {
        Ok(_) => println!("โœ… Normal(ฮผ=0.0, ฯƒ=-1.0) created successfully"),
        Err(e) => println!("โŒ Normal(ฮผ=0.0, ฯƒ=-1.0) failed: {:?}", e),
    }

    // Beta distribution parameters must be positive
    match Beta::new(0.0, 1.0) {
        Ok(_) => println!("โœ… Beta(ฮฑ=0.0, ฮฒ=1.0) created successfully"),
        Err(e) => println!("โŒ Beta(ฮฑ=0.0, ฮฒ=1.0) failed: {:?}", e),
    }

    // Poisson rate must be non-negative
    match Poisson::new(-1.0) {
        Ok(_) => println!("โœ… Poisson(ฮป=-1.0) created successfully"),
        Err(e) => println!("โŒ Poisson(ฮป=-1.0) failed: {:?}", e),
    }

Validation Rules:

  • Normal: ฯƒ > 0
  • Beta: ฮฑ > 0, ฮฒ > 0
  • Poisson: ฮป โ‰ฅ 0
  • Categorical: probabilities sum to 1, all non-negative
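In application code you will often want to propagate these construction errors rather than unwrapping them. A small illustrative sketch (the measurement_model helper and its error formatting are not part of Fugue's API):

    // Build a measurement model only when the supplied noise level is valid
    // (an illustrative helper, not part of Fugue's API).
    fn measurement_model(sigma: f64) -> Result<Model<f64>, String> {
        let noise = Normal::new(0.0, sigma)
            .map_err(|e| format!("invalid measurement noise: {:?}", e))?;
        Ok(sample(addr!("measurement"), noise))
    }

    // Usage:
    match measurement_model(-0.5) {
        Ok(_) => println!("✅ model constructed"),
        Err(msg) => println!("❌ {}", msg),
    }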

Storing Mixed Distributions

Use trait objects for collections of distributions with the same return type:

    // Storing different distributions together
    let continuous_dists: Vec<Box<dyn Distribution<f64>>> = vec![
        Normal::new(0.0, 1.0).unwrap().clone_box(),
        Beta::new(2.0, 5.0).unwrap().clone_box(),
        Uniform::new(-1.0, 1.0).unwrap().clone_box(),
    ];

    // Sample from each
    for (i, dist) in continuous_dists.iter().enumerate() {
        let sample = dist.sample(&mut rng);
        let dist_name = match i {
            0 => "Normal(0,1)",
            1 => "Beta(2,5)",
            2 => "Uniform(-1,1)",
            _ => "Unknown",
        };
        println!("๐Ÿ“ฆ {} sample: {:.3}", dist_name, sample);
    }

This enables dynamic distribution selection and model composition patterns.

Practical Modeling Patterns

Common modeling scenarios demonstrate natural type usage:

    // Practical modeling examples

    // Model a sensor with noise
    let true_temperature = 20.5; // True value
    let sensor_noise = Normal::new(0.0, 0.2).unwrap(); // Measurement error
    let measured_temp = true_temperature + sensor_noise.sample(&mut rng);
    println!(
        "๐ŸŒก๏ธ  True temperature: {:.2}ยฐC โ†’ Measured: {:.2}ยฐC",
        true_temperature, measured_temp
    );

    // Count model for arrivals
    let arrival_rate = Poisson::new(2.5).unwrap(); // 2.5 arrivals per hour
    let hourly_arrivals = arrival_rate.sample(&mut rng);
    println!(
        "๐Ÿšช Expected arrivals: 2.5/hour โ†’ Actual: {} arrivals",
        hourly_arrivals
    );

    // Decision model
    let decision_prob = 0.7;
    let decision = Bernoulli::new(decision_prob).unwrap();
    let will_buy = decision.sample(&mut rng);
    if will_buy {
        println!("๐Ÿ›’ Customer decision (p=0.7): Will make a purchase");
    } else {
        println!("๐Ÿšถ Customer decision (p=0.7): Will not purchase");
    }

Each distribution serves its natural domain without artificial conversions.

Working with Log-Probabilities

Logarithmic probability computation is essential for numerical stability in probabilistic programming. Consider the log-sum-exp operation for computing log Σᵢ exp(xᵢ), evaluated as

log Σᵢ exp(xᵢ) = m + log Σᵢ exp(xᵢ − m)

where m = maxᵢ xᵢ. This formulation prevents overflow when the xᵢ are large:

Numerical Stability Theorem

Direct computation of a product of probabilities Πᵢ p(xᵢ), where each factor is below 1, will underflow to machine zero after moderately many factors. Log-space computation of Σᵢ log p(xᵢ) remains stable for arbitrarily small probabilities, preserving up to 15-17 digits of precision in IEEE 754 double precision.

    // Working with log-probabilities

    let normal = Normal::new(100.0, 15.0).unwrap();

    // Multiple observations
    let observations = vec![98.5, 102.1, 99.8, 101.5, 97.2];
    let mut total_log_prob = 0.0;

    println!(
        "๐Ÿ“‹ Evaluating {} observations under Normal(ฮผ=100.0, ฯƒ=15.0):",
        observations.len()
    );
    for obs in &observations {
        let log_p = normal.log_prob(obs);
        total_log_prob += log_p;
        println!("   x={}: log P(x) = {:.3}", obs, log_p);
    }

    println!("๐Ÿ”ข Joint log-probability: {:.3}", total_log_prob);

    // Convert back to probability (be careful with underflow!)
    if total_log_prob > -700.0 {
        // Avoid underflow
        let probability = total_log_prob.exp();
        println!("๐Ÿ“Š Joint probability: {:.2e}", probability);
    } else {
        println!("โš ๏ธ  Joint probability too small to represent as f64");
    }

Warning

Converting large negative log-probabilities back to regular probabilities can underflow to zero. Keep computations in log-space when possible.
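When log-probabilities do need to be combined, for example to normalize importance weights, the log-sum-exp trick described above keeps the whole computation in log-space. A minimal sketch in plain Rust, independent of any Fugue API:

    /// log(Σᵢ exp(xᵢ)) computed stably by factoring out the maximum term.
    fn log_sum_exp(xs: &[f64]) -> f64 {
        let m = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        if !m.is_finite() {
            return m; // empty input (or all -inf) stays at -inf
        }
        m + xs.iter().map(|x| (x - m).exp()).sum::<f64>().ln()
    }

    // Three tiny probabilities stored as log-probabilities: each exp() alone
    // underflows, but their log-sum is perfectly representable.
    let log_ps = [-1000.0_f64, -1001.0, -1002.0];
    println!("🔢 log of summed probabilities: {:.3}", log_sum_exp(&log_ps));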

Advanced Patterns

For complex modeling scenarios, see these patterns:

Hierarchical Models

    // Hierarchical prior structure
    let global_mean = Normal::new(0.0, 10.0).unwrap();
    let mu = global_mean.sample(&mut rng);

    let group_precision = Gamma::new(2.0, 0.5).unwrap();
    let tau = group_precision.sample(&mut rng);
    let sigma = (1.0 / tau).sqrt(); // Convert precision to std dev

    // Individual observations from hierarchical model
    let individual = Normal::new(mu, sigma).unwrap();
    let observation = individual.sample(&mut rng);

    println!("๐ŸŒ Global mean: {:.3}", mu);
    println!("๐Ÿ“Š Group std dev: {:.3}", sigma);
    println!("๐Ÿ‘ค Individual observation: {:.3}", observation);

Mixture Components

    // Mixture model components
    let mixture_weights = vec![0.6, 0.3, 0.1];
    let component_selector = Categorical::new(mixture_weights).unwrap();
    let selected_component: usize = component_selector.sample(&mut rng);

    // Different components
    let components = [
        Normal::new(-2.0, 0.5).unwrap(),
        Normal::new(0.0, 1.0).unwrap(),
        Normal::new(3.0, 0.8).unwrap(),
    ];

    let sample = components[selected_component].sample(&mut rng);
    println!(
        "๐ŸŽฏ Selected component {}: sample = {:.3}",
        selected_component, sample
    );

Conjugate Priors

    // Beta-Bernoulli conjugacy
    let prior_alpha = 2.0;
    let prior_beta = 8.0;
    let prior = Beta::new(prior_alpha, prior_beta).unwrap();
    let p: f64 = prior.sample(&mut rng);

    // Simulate some trials
    let trials = 20;
    let mut successes = 0;
    let bernoulli = Bernoulli::new(p).unwrap();

    for _ in 0..trials {
        if bernoulli.sample(&mut rng) {
            successes += 1;
        }
    }

    // Posterior parameters (conjugate update)
    let posterior_alpha = prior_alpha + successes as f64;
    let posterior_beta = prior_beta + (trials - successes) as f64;
    let posterior = Beta::new(posterior_alpha, posterior_beta).unwrap();
    let updated_p = posterior.sample(&mut rng);

    println!("๐ŸŽฒ Prior p: {:.3}", p);
    println!("๐Ÿ“ˆ Observed: {}/{} successes", successes, trials);
    println!("๐Ÿ”„ Posterior p: {:.3}", updated_p);

Testing Your Distributions

Always test distribution properties and parameter validation:

    #[test]
    fn test_distribution_properties() {
        let mut rng = thread_rng();

        // Test type safety
        let coin = Bernoulli::new(0.5).unwrap();
        let flip: bool = coin.sample(&mut rng);
        assert!(flip == true || flip == false); // Must be boolean

        // Test parameter validation
        assert!(Normal::new(0.0, -1.0).is_err());
        assert!(Beta::new(0.0, 1.0).is_err());
        assert!(Poisson::new(-1.0).is_err());

        // Test valid distributions
        let normal = Normal::new(0.0, 1.0).unwrap();
        let sample = normal.sample(&mut rng);
        let log_prob = normal.log_prob(&sample);
        assert!(log_prob.is_finite());
    }

Common Pitfalls

  1. Underflow in probability space: Use log-probabilities for accumulation
  2. Parameter validation: Check constructor errors, don't assume success
  3. Precision with counts: Use u64 return types directly, avoid f64 conversion
  4. Categorical indexing: Trust the usize returnโ€”it's guaranteed valid

Next Steps

The type-safe distribution system eliminates entire classes of runtime errors while making statistical code more readable and maintainable.

Building Complex Models

Fugue's compositional architecture is grounded in category theory and monadic structures, enabling the systematic construction of sophisticated probabilistic models through principled composition operators. This guide explores the mathematical foundations and practical applications of Fugue's macro system for building complex probabilistic programs.

Categorical Foundations

Fugue models form a monad with:

  • Unit: pure: A → Model⟨A⟩, via pure()
  • Bind: Model⟨A⟩ → (A → Model⟨B⟩) → Model⟨B⟩, via .bind() and the prob! macro
  • Composition: satisfies the monad associativity and unit laws

This categorical structure ensures that model composition is mathematically sound and computationally tractable.
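As a concrete illustration of the unit law, the two programs below describe the same distribution over "y"; the equality is semantic (both denote the same model), and the helper function here is purely illustrative:

    // An illustrative model-producing function.
    fn noisy(a: f64) -> Model<f64> {
        sample(addr!("y"), Normal::new(a, 1.0).unwrap())
    }

    // Left identity: pure(a).bind(f) describes the same model as f(a).
    let lhs: Model<f64> = pure(3.0).bind(|a| noisy(a));
    let rhs: Model<f64> = noisy(3.0);
    // Both draw "y" from Normal(3.0, 1.0) when executed by a handler.
    let _ = (lhs, rhs);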

Do-Notation with prob!

The prob! macro implements monadic do-notation for probabilistic computations, providing a natural syntax for sequential dependence. Formally, it translates a block such as let x <- M₁; let y <- M₂; pure(f(x, y))

into the monadic composition M₁.bind(|x| M₂.bind(move |y| pure(f(x, y)))):

graph LR
    A[M₁] -->|bind| B["λx. M₂(x)"]
    B -->|bind| C["λy. pure(f(x, y))"]
    C --> D[Result]

    // Simple do-notation style probabilistic program
    let _simple_model = prob!(
        let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
        let y <- sample(addr!("y"), Normal::new(x, 0.5).unwrap());
        let sum = x + y;  // Regular variable assignment
        pure(sum)
    );
    println!("โœ… Created simple model with prob! macro");
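For comparison, the same program written with explicit combinators shows what the do-notation corresponds to (a conceptual desugaring, not the macro's literal expansion):

    // Equivalent model built directly from bind/map:
    let _desugared = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
        .bind(|x| sample(addr!("y"), Normal::new(x, 0.5).unwrap())
            .map(move |y| x + y));
    println!("✅ Built the same model with explicit bind/map");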

Key Features:

  • <- for probabilistic binding (monadic bind)
  • = for regular variable assignment
  • pure() to lift deterministic values
  • Natural control flow without callback nesting

Tip

Use prob! when you need to chain multiple probabilistic operations. It's especially powerful for dependent sampling where later variables depend on earlier ones.

Vectorized Operations with plate!

The plate! macro implements plate notation from graphical models, representing conditionally independent replications. Given N conditionally independent observations x_1, …, x_N, plate notation expresses the joint likelihood as p(x_1, …, x_N | θ) = ∏_i p(x_i | θ).

The computational graph shows the independence structure:

graph TB
    subgraph "Plate: i โˆˆ {1..N}"
        A[ฮธ] --> B1[xโ‚]
        A --> B2[xโ‚‚]
        A --> B3[...]
        A --> BN[xโ‚™]
    end

Conditional Independence

Plate notation encodes the conditional independence assumption x_i ⊥ x_j | θ for i ≠ j. This factorization enables efficient likelihood computation and parallel processing.

    // Independent samples using plate notation
    let _vector_model = plate!(i in 0..5 => {
        sample(addr!("sample", i), Normal::new(0.0, 1.0).unwrap())
    });
    println!("โœ… Created vectorized model with {} samples", 5);

    // Plate with observations
    let observations = [1.2, -0.5, 2.1, 0.8, -1.0];
    let n_obs = observations.len();
    let _observed_model = plate!(i in 0..n_obs => {
        observe(addr!("obs", i), Normal::new(0.0, 1.0).unwrap(), observations[i])
    });
    println!("โœ… Created observation model for {} data points", n_obs);

Benefits:

  • Automatic address indexing prevents conflicts
  • Natural iteration over data structures
  • Vectorized likelihood computations
  • Clear intent for independent operations

Note

The plate! macro automatically appends indices to addresses, so addr!("sample", i) becomes unique for each iteration without manual address management.
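
A quick sanity check makes the indexing visible; this sketch assumes, as the tests later in this guide do, that an address's underlying string is its first tuple field:

    // Indexed addresses are distinct, so plate! iterations never collide.
    let a0 = addr!("sample", 0);
    let a1 = addr!("sample", 1);
    assert_ne!(a0.0, a1.0); // the index is baked into the address string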

Hierarchical Address Management

Complex models require systematic parameter organization over a tree-structured address space. The address hierarchy forms a prefix tree in which each node represents a scope.

This hierarchical structure prevents address collisions and enables efficient parameter lookup.

Address Space Theory

The hierarchical address space forms a partially ordered set in which one address precedes another whenever it is a prefix of it. This structure ensures unique identification of parameters while maintaining compositional semantics.

    // Hierarchical model using scoped addresses
    let _hierarchical_model = prob!(
        let global_mu <- sample(addr!("global_mu"), Normal::new(0.0, 10.0).unwrap());
        let group_mu <- sample(scoped_addr!("group", "mu", "{}", 0),
                              Normal::new(global_mu, 1.0).unwrap());
        pure((global_mu, group_mu))
    );
    println!("โœ… Created hierarchical model with scoped addresses");

Address Strategy:

  • scoped_addr! prevents parameter name collisions
  • Hierarchical structure mirrors model dependencies
  • Systematic naming aids debugging and introspection
  • Indices enable parameter arrays

Sequential Dependencies

Sequential models exhibit temporal dependence: the state x_t at time t depends on earlier states, creating a Markov chain structure.

The computational challenge lies in maintaining state consistency while enabling efficient inference:

    // Sequential model with dependencies
    let _sequential_model = prob! {
        let states <- plate!(t in 0..3 => {
            sample(addr!("x", t), Normal::new(0.0, 1.0).unwrap())
                .bind(move |x_t| {
                    observe(addr!("y", t), Normal::new(x_t, 0.5).unwrap(), 1.0 + t as f64)
                        .map(move |_| x_t)
                })
        });

        pure(states)
    };
    println!("โœ… Created sequential model with observations");

Patterns:

  • Explicit state threading through computations
  • Observation conditioning at each time step
  • Autoregressive dependencies
  • Mixed probabilistic and deterministic updates

Warning

Sequential models can create large traces. Consider using memory-efficient handlers for long sequences.

Composable Model Functions

Build reusable model components:

    // Helper function to create a component model
    fn create_normal_component(name: &str, mean: f64, std: f64) -> Model<f64> {
        sample(addr!(name), Normal::new(mean, std).unwrap())
    }

    // Compose multiple components
    let _composition_model = prob! {
        let param1 <- create_normal_component("param1", 0.0, 1.0);
        let param2 <- create_normal_component("param2", 2.0, 0.5);
        let combined = param1 * param2;
        pure(combined)
    };
    println!("โœ… Created composed model with reusable components");

Design Principles:

  • Functions return Model<T> for composability
  • Pattern matching enables model selection
  • Pure functions for deterministic transformations
  • Higher-order functions for model templates

Advanced Address Patterns

For large-scale models like neural networks:

    // Complex addressing for large models
    let _neural_layer_model = plate!(layer in 0..3 => {
        let layer_size = match layer {
            0 => 4,
            1 => 8,
            2 => 1,
            _ => 1,
        };

        plate!(i in 0..layer_size => {
            sample(
                scoped_addr!("layer", "weight", "{}_{}", layer, i),
                Normal::new(0.0, 0.1).unwrap()
            )
        })
    });
    println!("โœ… Created neural network parameter structure");

Scaling Strategies:

  • Systematic parameter naming conventions
  • Multi-level scoping for complex architectures
  • Consistent indexing schemes
  • Hierarchical parameter organization

Mixing Styles for Flexibility

Combine macros with traditional function composition:

    // Mixture model with component selection
    let _mixture_model = prob! {
        let component <- sample(addr!("component"), Bernoulli::new(0.3).unwrap());
        let mu = if component { -2.0 } else { 2.0 };
        let x <- sample(addr!("x"), Normal::new(mu, 1.0).unwrap());
        pure((component, x))
    };
    println!("โœ… Created mixture model with 2 components");

Best Practices:

  • Use functions for reusable components
  • Use macros for readable composition
  • Separate concerns (priors, likelihood, observations)
  • Document parameter dependencies

Real-World Applications

Bayesian Linear Regression

Bayesian linear regression models the relationship y_i = intercept + slope · x_i + ε_i, with ε_i ~ Normal(0, σ²), quantifying uncertainty in all parameters:

    // Complete Bayesian linear regression
    let x_data = [1.0, 2.0, 3.0, 4.0, 5.0];
    let y_data = [2.1, 3.9, 6.2, 8.1, 9.8];
    let n = x_data.len();

    let _regression_model = prob! {
        let intercept <- sample(addr!("intercept"), Normal::new(0.0, 10.0).unwrap());
        let slope <- sample(addr!("slope"), Normal::new(0.0, 10.0).unwrap());
        let precision <- sample(addr!("precision"), Gamma::new(1.0, 1.0).unwrap());
        let sigma = (1.0 / precision).sqrt();

        let _likelihood <- plate!(i in 0..n => {
            let predicted = intercept + slope * x_data[i];
            observe(addr!("y", i), Normal::new(predicted, sigma).unwrap(), y_data[i])
        });

        pure((intercept, slope, sigma))
    };
    println!("โœ… Created Bayesian linear regression model");

Hierarchical Clustering

Hierarchical models implement partial pooling through multi-level parameter structures. The hierarchy enables information sharing across groups while maintaining group-specific effects:

    // Simplified hierarchy to avoid nested macro issues
    let _multilevel_model = prob!(
        let pop_mean <- sample(addr!("pop_mean"), Normal::new(0.0, 10.0).unwrap());
        let _pop_precision <- sample(addr!("pop_precision"), Gamma::new(2.0, 0.5).unwrap());
        let group_mean <- sample(scoped_addr!("group", "mean", "{}", 0),
                                Normal::new(pop_mean, 1.0).unwrap());
        pure((pop_mean, group_mean))
    );
    println!("โœ… Created hierarchical model structure");

State Space Models

Sequential latent variable models:

    // Sequential model with dependencies
    let _sequential_model = prob! {
        let states <- plate!(t in 0..3 => {
            sample(addr!("x", t), Normal::new(0.0, 1.0).unwrap())
                .bind(move |x_t| {
                    observe(addr!("y", t), Normal::new(x_t, 0.5).unwrap(), 1.0 + t as f64)
                        .map(move |_| x_t)
                })
        });

        pure(states)
    };
    println!("โœ… Created sequential model with observations");

Multi-Level Hierarchies

Population โ†’ Groups โ†’ Individuals structure:

    // Simplified hierarchy to avoid nested macro issues
    let _multilevel_model = prob!(
        let pop_mean <- sample(addr!("pop_mean"), Normal::new(0.0, 10.0).unwrap());
        let _pop_precision <- sample(addr!("pop_precision"), Gamma::new(2.0, 0.5).unwrap());
        let group_mean <- sample(scoped_addr!("group", "mean", "{}", 0),
                                Normal::new(pop_mean, 1.0).unwrap());
        pure((pop_mean, group_mean))
    );
    println!("โœ… Created hierarchical model structure");

Key Features:

  • Partial pooling across hierarchy levels
  • Systematic parameter organization
  • Natural shrinkage properties
  • Scalable to large group structures

Configurable Model Factories

Dynamic model construction (a runtime-configurable factory is sketched at the end of this section):

    // Helper function to create a component model
    fn create_normal_component(name: &str, mean: f64, std: f64) -> Model<f64> {
        sample(addr!(name), Normal::new(mean, std).unwrap())
    }

    // Compose multiple components
    let _composition_model = prob! {
        let param1 <- create_normal_component("param1", 0.0, 1.0);
        let param2 <- create_normal_component("param2", 2.0, 0.5);
        let combined = param1 * param2;
        pure(combined)
    };
    println!("โœ… Created composed model with reusable components");

Flexibility Benefits:

  • Runtime model configuration
  • Conditional model components
  • A/B testing different model structures
  • Experiment management
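
Taking this one step further, a factory function can choose the model structure at runtime. The sketch below is illustrative; the PriorKind enum and the specific priors are invented for this example:

    // Hypothetical configuration enum used to pick a prior at runtime.
    enum PriorKind {
        WeaklyInformative,
        Tight,
    }

    fn make_location_model(prior: PriorKind) -> Model<f64> {
        // Pattern matching selects which component to build.
        match prior {
            PriorKind::WeaklyInformative => sample(addr!("mu"), Normal::new(0.0, 10.0).unwrap()),
            PriorKind::Tight => sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()),
        }
    }

    // The same call site can run either variant, e.g. for A/B testing.
    let _model_a = make_location_model(PriorKind::WeaklyInformative);
    let _model_b = make_location_model(PriorKind::Tight);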

Testing Complex Models

Model validation requires systematic testing across multiple dimensions: syntactic correctness, semantic validity, and statistical consistency:

graph TD
    A[Model M] --> B[Syntactic Tests]
    A --> C[Semantic Tests]
    A --> D[Statistical Tests]

    B --> E[Type Checking]
    B --> F[Address Uniqueness]

    C --> G[Trace Validity]
    C --> H[Parameter Bounds]

    D --> I[Prior Predictive]
    D --> J[Posterior Consistency]

    E --> K{All Pass?}
    F --> K
    G --> K
    H --> K
    I --> K
    J --> K

    K -->|Yes| L[Model Validated]
    K -->|No| M[Refinement Required]

Testing Hierarchy:

  1. Unit Tests: Individual model components
  2. Integration Tests: Model composition correctness
  3. Statistical Tests: Distributional properties
  4. Performance Tests: Scalability and efficiency
    #[test]
    fn test_model_composition() {
        // Test that models construct without errors
        let _simple = prob! {
            let x <- sample(addr!("test_x"), Normal::new(0.0, 1.0).unwrap());
            pure(x)
        };

        // Test plate notation
        let _plate_model = plate!(i in 0..3 => {
            sample(addr!("plate_test", i), Normal::new(0.0, 1.0).unwrap())
        });

        // Test scoped addresses
        let addr1 = scoped_addr!("test", "param");
        let addr2 = scoped_addr!("test", "param", "{}", 42);

        // Addresses should be different
        assert_ne!(addr1.0, addr2.0);
        assert!(addr2.0.contains("42"));

        // Test hierarchical model construction
        let _hierarchical = prob! {
            let global <- sample(addr!("global"), Normal::new(0.0, 1.0).unwrap());
            let locals <- plate!(i in 0..2 => {
                sample(scoped_addr!("local", "param", "{}", i),
                       Normal::new(global, 0.1).unwrap())
            });
            pure((global, locals))
        };

        // All models should construct successfully
        // (Actual execution would require handlers)
    }

Common Pitfalls

  1. Address Conflicts: Use scoped_addr! for complex models
  2. Memory Usage: Large plate operations can create big traces
  3. Sequential Dependencies: Explicit state management required
  4. Type Inference: Sometimes need explicit type annotations

Performance Considerations

  • Plate Size: Very large plates may exceed memory limits
  • Nesting Depth: Deep hierarchies increase trace size
  • Address Complexity: Simple addresses are more efficient
  • Function Composition: Pure functions are optimized away

Next Steps

Compositional Excellence

Building complex models successfully combines mathematical rigor with practical implementation:

  1. Categorical Foundations: Monadic structure ensures compositionality
  2. Systematic Organization: Hierarchical addressing prevents conflicts
  3. Efficient Computation: Plate notation enables vectorization
  4. Validation Framework: Multi-level testing ensures correctness

These patterns transform complex probabilistic modeling from ad-hoc construction into principled composition.

Complex models become tractable and maintainable through systematic composition, principled addressing, and mathematical abstraction. Fugue's macro system provides elegant syntactic sugar while preserving the underlying categorical structure that enables powerful inference algorithms and compositional reasoning about probabilistic programs.

Optimizing Performance

Performance optimization in probabilistic programming requires understanding both computational complexity and numerical analysis. This guide explores Fugue's systematic approach to memory optimization, numerical stability, and algorithmic efficiency for production-scale probabilistic workloads.

Computational Complexity Framework

Probabilistic programs exhibit multi-dimensional complexity:

  • Time complexity: grows with the number of samples, the number of parameters, and the number of inference iterations
  • Space complexity: proportional to trace size, reduced to a reusable constant working set with memory pooling
  • Numerical complexity: the condition number of the target distribution affects convergence speed

Fugue's optimization framework addresses each dimension systematically.

Memory-Optimized Inference

Memory allocation becomes the computational bottleneck in high-throughput scenarios due to allocation and deallocation overhead. For naive inference, the allocation rate scales with the product of the sample count and the per-trace size. Fugue's object pooling amortizes this to a constant number of allocations per sample after warmup:

graph TD
    subgraph "Traditional Allocation"
        A1["Sample 1"] --> B1["Allocate Trace"]
        B1 --> C1["GC Pressure"]
        A2["Sample 2"] --> B2["Allocate Trace"]  
        B2 --> C2["GC Pressure"]
        A3["Sample n"] --> B3["Allocate Trace"]
        B3 --> C3["GC Pressure"]
    end
    
    subgraph "Pooled Allocation"  
        D1["Sample 1"] --> E1["Reuse from Pool"]
        D2["Sample 2"] --> E1
        D3["Sample n"] --> E1
        E1 --> F["Zero GC Pressure"]
    end
    // Create trace pool for zero-allocation inference
    let mut pool = TracePool::new(50); // Pool up to 50 traces
    let mut rng = thread_rng();

    // Define a model that would normally cause many allocations
    let make_model = || {
        prob!(
            let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
            let y <- sample(addr!("y"), Normal::new(x, 0.5).unwrap());
            observe(addr!("obs"), Normal::new(y, 0.1).unwrap(), 1.5);
            pure(x)
        )
    };

    // Time pooled vs non-pooled execution
    let start = Instant::now();
    for _iteration in 0..1000 {
        // Use pooled handler for efficient memory reuse
        let (_result, trace) =
            runtime::handler::run(PooledPriorHandler::new(&mut rng, &mut pool), make_model());
        // Return trace to pool for reuse
        pool.return_trace(trace);
    }
    let pooled_time = start.elapsed();

    let stats = pool.stats();
    println!("โœ… Completed 1000 iterations with memory pooling");
    println!("   - Execution time: {:?}", pooled_time);
    println!("   - Hit ratio: {:.1}%", stats.hit_ratio());
    println!(
        "   - Pool stats - hits: {}, misses: {}",
        stats.hits, stats.misses
    );

Key Benefits:

  • Zero-allocation execution after warm-up
  • Configurable pool size for memory control
  • Automatic trace recycling and cleanup
  • Built-in performance monitoring with hit ratios

Numerical Stability

Numerical stability in probabilistic computing requires careful analysis of condition numbers and floating-point precision. The log-sum-exp operation is fundamental:

LSE(x_1, …, x_n) = m + log Σ_i exp(x_i − m), where m = max_i x_i

Stability Analysis: Computing log Σ_i exp(x_i) directly becomes ill-conditioned when the x_i are large in magnitude, because exp(x_i) overflows or underflows before the final logarithm can rescale the result.

Catastrophic Cancellation

When the x_i are large and similar, direct computation suffers from overflow and catastrophic cancellation; the LSE formulation maintains relative precision regardless of scale.
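
For intuition, the max-shift trick behind a stable log-sum-exp can be written in a few lines of plain Rust; this is an illustrative reimplementation, not Fugue's actual log_sum_exp:

    /// Numerically stable log(Σ exp(x_i)) via the max-shift trick.
    fn log_sum_exp_sketch(xs: &[f64]) -> f64 {
        let m = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        if !m.is_finite() {
            return m; // empty input, all -inf, or an infinite passthrough
        }
        m + xs.iter().map(|&x| (x - m).exp()).sum::<f64>().ln()
    }

    // Inputs near 700 push exp() close to f64 overflow; the shifted sum stays well-scaled.
    assert!(log_sum_exp_sketch(&[700.0, 701.0, 699.0]).is_finite());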

    // Demonstrate stable log-probability computations
    let extreme_log_probs = vec![700.0, 701.0, 699.0, 698.0]; // Would overflow in linear space

    // Safe log-sum-exp prevents overflow
    let log_normalizer = log_sum_exp(&extreme_log_probs);
    let normalized_probs = normalize_log_probs(&extreme_log_probs);

    println!("โœ… Stable computation with extreme log-probabilities");
    println!("   - Log normalizer: {:.2}", log_normalizer);
    println!(
        "   - Probabilities sum to: {:.10}",
        normalized_probs.iter().sum::<f64>()
    );

    // Weighted log-sum-exp for importance sampling
    let log_values = vec![-1.0, -2.0, -3.0, -4.0];
    let weights = vec![0.4, 0.3, 0.2, 0.1];
    let weighted_result = weighted_log_sum_exp(&log_values, &weights);

    println!("   - Weighted log-sum-exp: {:.4}", weighted_result);

    // Safe logarithm handling
    let safe_results: Vec<f64> = [1.0, 0.0, -1.0].iter().map(|&x| safe_ln(x)).collect();
    println!("   - Safe ln results: {:?}", safe_results);

Stability Features:

  • log_sum_exp prevents overflow in mixture computations
  • weighted_log_sum_exp for importance sampling
  • safe_ln handles edge cases gracefully
  • All operations maintain numerical precision across scales

Efficient Trace Construction

When building traces programmatically, use TraceBuilder for optimal performance:

    // Use TraceBuilder for efficient trace creation
    let mut builder = TraceBuilder::new();

    let start = Instant::now();
    for i in 0..100 {
        // Add choices efficiently without reallocations
        builder.add_sample(
            addr!("param", i),
            i as f64,
            0.0, // log_prob
        );
    }

    // Build final trace efficiently
    let constructed_trace = builder.build();
    let construction_time = start.elapsed();

    println!("โœ… Efficient trace construction");
    println!(
        "   - Built trace with {} choices in {:?}",
        constructed_trace.choices.len(),
        construction_time
    );
    println!(
        "   - Total log weight: {:.2}",
        constructed_trace.total_log_weight()
    );

Construction Benefits:

  • Pre-allocated data structures minimize reallocations
  • Type-specific insertion methods avoid boxing overhead
  • Batch operations for multiple choices
  • Efficient conversion to immutable traces

Copy-on-Write for MCMC

MCMC algorithms exhibit temporal locality in parameter updates, modifying only a small subset of the d total parameters per iteration. Copy-on-Write (COW) data structures exploit this pattern:

graph TD
    subgraph "MCMC Iteration Structure"
        A["Base Trace Tโ‚€"] --> B{"Proposal Step"}
        B --> C["Modified Parameters ฮด"]
        C --> D{"Small Changes?"}
        D -->|Yes| E["COW: Share + ฮ”"]
        D -->|No| F["Full Copy"]
        E --> G["O(1) Memory"]
        F --> H["O(d) Memory"]
    end

Complexity Analysis: Traditional MCMC requires O(d) space per sample, where d is the number of parameters. COW reduces this to O(δ), where δ is the edit distance between consecutive traces.

    // Create base trace manually for MCMC
    let mut builder = TraceBuilder::new();
    builder.add_sample(addr!("mu"), 0.5, -0.5);
    builder.add_sample(addr!("sigma"), 1.0, -1.0);
    builder.add_sample_bool(addr!("component"), true, -0.69);
    let base_trace = builder.build();

    // Create COW trace for efficient copying
    let cow_base = CowTrace::from_trace(base_trace);

    let start = Instant::now();
    let mut mcmc_traces = Vec::new();

    for _proposal in 0..1000 {
        // Clone is O(1) until modification
        let mut proposal_trace = cow_base.clone();

        // Modify only one parameter (triggers COW)
        proposal_trace.insert_choice(
            addr!("mu"),
            Choice {
                addr: addr!("mu"),
                value: ChoiceValue::F64(0.6),
                logp: -0.4,
            },
        );

        mcmc_traces.push(proposal_trace);
    }
    let cow_time = start.elapsed();

    println!("โœ… Copy-on-write MCMC proposals");
    println!("   - Created 1000 proposal traces in {:?}", cow_time);
    println!("   - Memory sharing until modification");

MCMC Optimizations:

  • O(1) trace cloning until modification
  • Shared memory for unchanged parameters
  • Lazy copying only when traces diverge
  • Perfect for Metropolis-Hastings and Gibbs sampling

Vectorized Model Patterns

Structure models for efficient batch processing:

    // Pre-allocate data structures for repeated use
    let observations: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
    let n = observations.len();

    // Efficient vectorized model
    let vectorized_model = || {
        prob!(
            let mu <- sample(addr!("global_mu"), Normal::new(0.0, 10.0).unwrap());
            let precision <- sample(addr!("precision"), Gamma::new(2.0, 1.0).unwrap());
            let sigma = (1.0 / precision).sqrt();

            // Use plate for efficient vectorized operations
            let _likelihoods <- plate!(i in 0..n => {
                observe(addr!("obs", i), Normal::new(mu, sigma).unwrap(), observations[i])
            });

            pure((mu, sigma))
        )
    };

    let start = Instant::now();
    let (_result, _trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        vectorized_model(),
    );
    let vectorized_time = start.elapsed();

    println!("โœ… Optimized vectorized model");
    println!("   - Processed {} observations in {:?}", n, vectorized_time);

Vectorization Strategy:

  • Pre-allocate data collections
  • Use plate! for independent parallel operations
  • Minimize dynamic allocations in hot paths
  • Leverage compiler optimizations with static sizing

Performance Monitoring

Systematic performance monitoring tracks multiple metrics (latency, memory, numerical validity) against their expected bounds, for example comparing the observed time per sample with the theoretical minimum execution time per sample.

Amdahl's Law for MCMC

Even with perfect parallelization, MCMC exhibits sequential dependencies that limit speedup according to Amdahl's law: speedup(P) = 1 / (s + (1 − s)/P), where s is the fraction of sequential computation and P is the number of processors.
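
As a quick worked example of this bound (plain arithmetic, not a Fugue API), even a small sequential fraction caps the achievable speedup:

    /// Amdahl's law: speedup with sequential fraction `s` on `p` processors.
    fn amdahl_speedup(s: f64, p: f64) -> f64 {
        1.0 / (s + (1.0 - s) / p)
    }

    // With 10% sequential work, 16 cores give 6.4x, and the limit as p grows is 10x.
    assert!((amdahl_speedup(0.10, 16.0) - 6.4).abs() < 1e-9);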

    // Monitor trace characteristics for optimization insights
    #[derive(Debug)]
    struct TraceMetrics {
        num_choices: usize,
        log_weight: f64,
        is_valid: bool,
        memory_size_estimate: usize,
    }

    impl TraceMetrics {
        fn from_trace(trace: &Trace) -> Self {
            let num_choices = trace.choices.len();
            let log_weight = trace.total_log_weight();
            let is_valid = log_weight.is_finite();

            // Rough memory estimate (actual implementation would be more precise)
            let memory_size_estimate = num_choices * 64; // Rough bytes per choice

            Self {
                num_choices,
                log_weight,
                is_valid,
                memory_size_estimate,
            }
        }
    }

    // Example: Monitor a complex model's performance
    let complex_model = || {
        prob!(
            let components <- plate!(c in 0..5 => {
                sample(addr!("weight", c), Gamma::new(1.0, 1.0).unwrap())
                    .bind(move |weight| {
                        sample(addr!("mu", c), Normal::new(0.0, 2.0).unwrap())
                            .map(move |mu| (weight, mu))
                    })
            });

            let selector <- sample(addr!("selector"),
                                  Categorical::new(vec![0.2, 0.2, 0.2, 0.2, 0.2]).unwrap());

            pure((components, selector))
        )
    };

    let (_result, trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        complex_model(),
    );

    let metrics = TraceMetrics::from_trace(&trace);
    println!("โœ… Performance monitoring active");
    println!("   - Trace choices: {}", metrics.num_choices);
    println!("   - Log weight: {:.2}", metrics.log_weight);
    println!("   - Valid: {}", metrics.is_valid);
    println!(
        "   - Memory estimate: {} bytes",
        metrics.memory_size_estimate
    );

Monitoring Approach:

  • Collect trace characteristics for optimization insights
  • Track memory usage patterns
  • Validate numerical stability
  • Profile execution bottlenecks

Batch Processing

Batch processing amortizes setup costs and exploits hardware parallelism. The optimal batch size balances memory usage against throughput by trading off:

  • the per-batch initialization cost
  • the per-sample memory cost
  • the synchronization overhead
graph LR
    subgraph "Performance vs Batch Size"
        A["Small Batches<br/>b โ†’ 1"] --> B["High Setup<br/>Overhead"]
        C["Large Batches<br/>b โ†’ โˆž"] --> D["Memory<br/>Pressure"]  
        E["Optimal Batch<br/>b*"] --> F["Balanced<br/>Performance"]
    end
    // Efficient batch inference using memory pooling
    let batch_size = 100;
    let mut batch_pool = TracePool::new(batch_size);

    let start = Instant::now();
    let mut batch_results = Vec::with_capacity(batch_size);

    for _batch in 0..batch_size {
        let (result, trace) = runtime::handler::run(
            PooledPriorHandler::new(&mut rng, &mut batch_pool),
            make_model(),
        );
        // Return trace to pool for reuse
        batch_pool.return_trace(trace);
        batch_results.push(result);
    }

    let batch_time = start.elapsed();
    let batch_stats = batch_pool.stats();

    println!("โœ… Batch processing complete");
    println!("   - Processed {} samples in {:?}", batch_size, batch_time);
    println!(
        "   - Average time per sample: {:?}",
        batch_time / batch_size as u32
    );
    println!(
        "   - Memory efficiency: {:.1}% hit ratio",
        batch_stats.hit_ratio()
    );

Batch Benefits:

  • Amortized setup costs across samples
  • Memory pool reuse for consistent performance
  • Scalable to large sample counts
  • Predictable memory footprint

Numerical Precision Testing

Validate stability across different computational scales:

    // Test numerical stability across different scales
    let test_scales = vec![1e-10, 1e-5, 1.0, 1e5, 1e10];

    for &scale in &test_scales {
        let scale: f64 = scale;
        let log_vals = vec![scale.ln() + 1.0, scale.ln() + 2.0, scale.ln() + 0.5];

        let stable_sum = log_sum_exp(&log_vals);
        let log1p_result = log1p_exp(scale.ln());

        println!(
            "   Scale {:.0e}: log_sum_exp={:.4}, log1p_exp={:.4}",
            scale, stable_sum, log1p_result
        );
    }
    println!("โœ… Numerical stability verified across scales");

Testing Strategy:

  • Verify stability across extreme value ranges
  • Test edge cases and boundary conditions
  • Validate consistency of numerical operations
  • Profile precision vs. performance trade-offs

Performance Testing

Implement systematic performance validation:

    #[test]
    fn test_memory_pool_efficiency() {
        let mut pool = TracePool::new(10);
        let mut rng = thread_rng();

        // Test pool reuse with PooledPriorHandler
        for _i in 0..20 {
            let (_, trace) = runtime::handler::run(
                PooledPriorHandler::new(&mut rng, &mut pool),
                sample(addr!("test"), Normal::new(0.0, 1.0).unwrap()),
            );
            // Return trace to pool for reuse
            pool.return_trace(trace);
        }

        let stats = pool.stats();
        assert!(
            stats.hit_ratio() > 50.0,
            "Pool should have good hit ratio, got {:.1}%",
            stats.hit_ratio()
        );
        assert!(stats.hits + stats.misses > 0, "Pool should have been used");
    }

    #[test]
    fn test_numerical_stability() {
        // Test log_sum_exp with extreme values
        let extreme_vals = vec![700.0, 701.0, 699.0];
        let result = log_sum_exp(&extreme_vals);
        assert!(
            result.is_finite(),
            "log_sum_exp should handle extreme values"
        );

        // Test normalization
        let normalized = normalize_log_probs(&extreme_vals);
        let sum: f64 = normalized.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "Normalized probabilities should sum to 1"
        );

        // Test weighted computation
        let weights = vec![0.5, 0.3, 0.2];
        let weighted_result = weighted_log_sum_exp(&extreme_vals, &weights);
        assert!(
            weighted_result.is_finite(),
            "Weighted log_sum_exp should be finite"
        );
    }

    #[test]
    fn test_trace_builder_efficiency() {
        let mut builder = TraceBuilder::new();

        // Add many choices efficiently
        for i in 0..100 {
            builder.add_sample(addr!("param", i), i as f64, -0.5);
        }

        let trace = builder.build();
        assert_eq!(trace.choices.len(), 100);
        assert!(trace.total_log_weight().is_finite());
    }

    #[test]
    fn test_cow_trace_sharing() {
        // Create base trace using builder
        let mut builder = TraceBuilder::new();
        builder.add_sample(addr!("x"), 1.0, -0.5);
        let base = builder.build();
        let cow_trace = CowTrace::from_trace(base);

        // Clone should be fast
        let clone1 = cow_trace.clone();
        let clone2 = cow_trace.clone();

        // Should share data until modification - convert to regular trace to test
        let trace1 = clone1.to_trace();
        let trace2 = clone2.to_trace();
        assert_eq!(trace1.get_f64(&addr!("x")), Some(1.0));
        assert_eq!(trace2.get_f64(&addr!("x")), Some(1.0));
    }

Testing Framework:

  • Memory pool efficiency validation
  • Numerical stability regression tests
  • Trace construction benchmarking
  • COW sharing verification

Production Deployment

Memory Configuration

  • Size TracePool based on peak concurrent inference
  • Monitor hit ratios to validate pool efficiency
  • Use COW traces for MCMC workloads
  • Pre-warm pools before production traffic (see the sketch below)
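
One simple way to pre-warm is to push a batch of throwaway runs through the pooled handler before serving traffic, so the first real requests already hit recycled traces. This sketch reuses only the pool and handler calls shown earlier; the pool size of 64 and the dummy warmup model are arbitrary:

    // Pre-warm a pool sized for the expected number of concurrent inferences.
    let mut pool = TracePool::new(64);
    let mut rng = thread_rng();
    for _ in 0..64 {
        let (_, trace) = runtime::handler::run(
            PooledPriorHandler::new(&mut rng, &mut pool),
            sample(addr!("warmup"), Normal::new(0.0, 1.0).unwrap()),
        );
        pool.return_trace(trace);
    }
    let stats = pool.stats();
    println!("warmup complete, hit ratio so far: {:.1}%", stats.hit_ratio());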

Numerical Strategies

  • Always use log-space for probability computations
  • Validate extreme value handling in testing
  • Monitor for numerical instabilities in production
  • Use stable algorithms for critical computations

Monitoring and Alerting

  • Track inference latency and memory usage
  • Monitor pool statistics and efficiency metrics
  • Alert on numerical instabilities or performance degradation
  • Profile hot paths for optimization opportunities

Common Performance Patterns

  1. Pool First: Use TracePool for any repeated inference
  2. Log Always: Work in log-space for numerical stability
  3. Batch Everything: Amortize costs across multiple samples
  4. Monitor Continuously: Track performance metrics in production
  5. Test Extremes: Validate stability with extreme values

These optimization strategies enable Fugue to handle production-scale probabilistic programming workloads with consistent performance and numerical reliability.

Debugging Models

Debugging probabilistic models presents unique challenges due to their stochastic nature and high-dimensional parameter spaces. Unlike deterministic programs, probabilistic models require statistical validation, convergence analysis, and distributional testing. This guide establishes a systematic methodology for probabilistic model debugging using Fugue's comprehensive diagnostic framework.

Probabilistic Debugging Theory

Model debugging operates on multiple abstraction levels:

  • Syntactic: Code structure and type correctness
  • Semantic: Model specification and parameter validity
  • Statistical: Distributional properties and moment consistency
  • Computational: Numerical stability and convergence behavior

Each level requires specialized diagnostic techniques and validation criteria.

Trace Inspection and Analysis

Execution traces form the foundation of probabilistic model debugging. Each trace contains a complete record of the program's stochastic execution as a set of choices (a_i, v_i, logp_i), where a_i is the address, v_i is the sampled value, and logp_i is the log-weight contribution.

graph TD
    subgraph "Trace Analysis Workflow"
        A[Execution Trace T] --> B{Finite Log-Weight?}
        B -->|No| C[Numerical Instability]
        B -->|Yes| D[Choice Analysis]
        D --> E[Parameter Extraction]
        E --> F[Statistical Validation]
        F --> G{Passes Tests?}
        G -->|No| H[Model Refinement]
        G -->|Yes| I[Model Validated]
        C --> J[Debug Constraints]
        H --> A
        J --> A
    end

Mathematical Properties: A valid trace must satisfy the weight consistency condition: the total log-weight decomposes as log_prior + log_likelihood + log_factors, and every component must be finite.

    // Execute a model and examine its trace structure
    let mut rng = thread_rng();

    let diagnostic_model = || {
        prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap());
            let sigma <- sample(addr!("sigma"), Gamma::new(2.0, 1.0).unwrap());
            observe(addr!("obs1"), Normal::new(mu, sigma).unwrap(), 1.5);
            observe(addr!("obs2"), Normal::new(mu, sigma).unwrap(), 1.2);
            factor(if mu.abs() < 3.0 { 0.0 } else { f64::NEG_INFINITY });
            pure((mu, sigma))
        )
    };

    let ((mu_val, sigma_val), trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        diagnostic_model(),
    );

    println!("โœ… Model execution complete");
    println!("   - Result: mu = {:.3}, sigma = {:.3}", mu_val, sigma_val);
    println!("   - Choices recorded: {}", trace.choices.len());
    println!("   - Prior log-weight: {:.6}", trace.log_prior);
    println!("   - Likelihood log-weight: {:.6}", trace.log_likelihood);
    println!("   - Factor log-weight: {:.6}", trace.log_factors);
    println!("   - Total log-weight: {:.6}", trace.total_log_weight());

    // Per-choice breakdown
    println!("   - Choice breakdown:");
    for (addr, choice) in &trace.choices {
        println!(
            "     {}: {:?} (logp: {:.6})",
            addr, choice.value, choice.logp
        );
    }

Key Debugging Insights:

  • Choice count reveals model complexity and structure
  • Log-weight decomposition identifies prior vs. likelihood vs. factor issues
  • Per-choice analysis shows individual parameter contributions
  • Finite log-weights indicate valid model execution

Type-Safe Value Access

Fugue provides robust access patterns that handle type mismatches gracefully:

    // Safe access patterns that handle type mismatches gracefully

    // Option-based access (returns None on mismatch)
    match trace.get_f64(&addr!("mu")) {
        Some(mu) => println!("โœ… Retrieved mu = {:.3}", mu),
        None => println!("โŒ Failed to get mu as f64"),
    }

    // Result-based access (returns detailed error info)
    match trace.get_f64_result(&addr!("sigma")) {
        Ok(sigma) => println!("โœ… Retrieved sigma = {:.3}", sigma),
        Err(e) => println!("โŒ Error getting sigma: {}", e),
    }

    // Handle missing addresses
    match trace.get_f64_result(&addr!("missing_param")) {
        Ok(_) => unreachable!(),
        Err(e) => println!("โœ… Correctly caught missing address: {}", e),
    }

    // Handle type mismatches
    match trace.get_bool_result(&addr!("mu")) {
        Ok(_) => unreachable!(),
        Err(e) => println!("โœ… Correctly caught type mismatch: {}", e),
    }

    // Iterate through all choices for debugging
    println!("   - All choices and their types:");
    for (addr, choice) in &trace.choices {
        let type_info = match &choice.value {
            ChoiceValue::F64(_) => "f64",
            ChoiceValue::Bool(_) => "bool",
            ChoiceValue::U64(_) => "u64",
            ChoiceValue::I64(_) => "i64",
            ChoiceValue::Usize(_) => "usize",
        };
        println!("     {} ({}): {:?}", addr, type_info, choice.value);
    }

Error Handling Strategies:

  • Use get_*_result() for detailed error information
  • Use get_*() for simple None-handling
  • Always check for missing addresses before assuming success
  • Iterate through all choices to understand model structure

Model Validation and Testing

Systematic validation ensures your model behaves as expected:

    // Test a simple conjugate model against analytical solution
    let conjugate_model = || {
        prob!(
            let theta <- sample(addr!("theta"), Beta::new(1.0, 1.0).unwrap());
            observe(addr!("successes"), Binomial::new(10, theta).unwrap(), 7u64);
            pure(theta)
        )
    };

    // Run a few samples to test basic functionality
    let mut theta_samples = Vec::new();
    for _ in 0..20 {
        let (theta, test_trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            conjugate_model(),
        );

        // Validate trace structure
        assert!(test_trace.choices.contains_key(&addr!("theta")));
        assert!(
            test_trace.total_log_weight().is_finite(),
            "Trace should have finite log-weight"
        );
        assert!(
            test_trace.log_likelihood.is_finite(),
            "Likelihood should be finite"
        );

        theta_samples.push(theta);
    }

    // Basic statistical checks
    let sample_mean = theta_samples.iter().sum::<f64>() / theta_samples.len() as f64;
    println!("โœ… Validation tests passed");
    println!("   - Generated {} samples", theta_samples.len());
    println!(
        "   - Sample mean: {:.3} (expected ~0.7 for Beta-Binomial)",
        sample_mean
    );
    println!("   - All traces had finite log-weights");

Validation Best Practices:

  • Test against known analytical solutions (see the worked check below)
  • Verify all traces have finite log-weights
  • Check basic statistical properties (means, variances)
  • Test edge cases and boundary conditions
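
For the Beta-Binomial model above there is an exact target for such a test: a Beta(1, 1) prior updated with 7 successes in 10 trials yields a Beta(8, 4) posterior with mean 2/3, so posterior-targeting inference should approach roughly 0.67 (raw prior samples, as drawn above, will hover near the prior mean of 0.5). A few lines of plain arithmetic encode the analytical answer:

    // Conjugate update: Beta(1, 1) prior + 7 successes in 10 trials => Beta(8, 4).
    let (prior_alpha, prior_beta) = (1.0_f64, 1.0_f64);
    let (successes, trials) = (7.0_f64, 10.0_f64);
    let posterior_mean = (prior_alpha + successes) / (prior_alpha + prior_beta + trials);
    assert!((posterior_mean - 2.0 / 3.0).abs() < 1e-12);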

Safe vs Strict Error Handling

Fugue provides both strict (fail-fast) and safe (error-resilient) execution modes:

    // Create a trace with known structure for replay testing
    let mut base_trace = Trace::default();
    base_trace.insert_choice(addr!("param"), ChoiceValue::F64(1.5), -0.5);

    let test_model = || sample(addr!("param"), Normal::new(0.0, 1.0).unwrap());

    // Strict replay - will panic on mismatch (commented out for safety)
    // let strict_replay = ReplayHandler { base_trace: &base_trace };
    // let (strict_result, strict_trace) = runtime::handler::run(strict_replay, test_model());

    // Safe replay - handles errors gracefully
    let safe_replay = SafeReplayHandler {
        rng: &mut rng,
        base: base_trace.clone(),
        trace: Trace::default(),
        warn_on_mismatch: true,
    };
    let (safe_result, safe_trace) = runtime::handler::run(safe_replay, test_model());

    println!("โœ… Safe replay succeeded");
    println!("   - Result: {:.3}", safe_result);
    println!(
        "   - Retrieved value: {:?}",
        safe_trace.get_f64(&addr!("param"))
    );

    // Test scoring with safe handler
    let safe_score = SafeScoreGivenTrace {
        base: base_trace,
        trace: Trace::default(),
        warn_on_error: false,
    };
    let (_, score_trace) = runtime::handler::run(safe_score, test_model());

    println!(
        "   - Score trace log-weight: {:.3}",
        score_trace.total_log_weight()
    );

When to Use Each:

  • Strict handlers (ReplayHandler, ScoreGivenTrace): Development and testing
  • Safe handlers (SafeReplayHandler, SafeScoreGivenTrace): Production systems
  • Safe handlers log warnings instead of panicking on mismatches

MCMC Diagnostics

Markov Chain Monte Carlo convergence assessment requires statistical hypothesis testing and diagnostic metrics. The fundamental question is whether the chain has reached its stationary distribution π.

Gelman-Rubin Diagnostic

The potential scale reduction factor R-hat compares within-chain and between-chain variance:

R-hat = sqrt(V_hat / W)

where:

  • W is the average within-chain variance
  • B is the between-chain variance
  • V_hat = ((n − 1)/n)·W + (1/n)·B is the marginal posterior variance estimate, with n samples per chain

Convergence Criterion

Theoretical Result: As n → ∞, R-hat → 1 if the chain has converged. Practical Threshold: R-hat < 1.1 indicates approximate convergence for most applications. Statistical Interpretation: R-hat well above 1 suggests the chains haven't explored the full posterior distribution.
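
To make the diagnostic concrete, here is an illustrative R-hat computation over equal-length chains of f64 samples; it is a plain-Rust sketch of the formula above, not Fugue's r_hat_f64:

    /// Potential scale reduction factor for m >= 2 equal-length chains (illustrative).
    fn r_hat_sketch(chains: &[Vec<f64>]) -> f64 {
        let m = chains.len() as f64;
        let n = chains[0].len() as f64;
        let chain_means: Vec<f64> = chains.iter().map(|c| c.iter().sum::<f64>() / n).collect();
        let grand_mean = chain_means.iter().sum::<f64>() / m;

        // B: between-chain variance, W: mean within-chain variance.
        let b = n / (m - 1.0)
            * chain_means.iter().map(|mu| (mu - grand_mean).powi(2)).sum::<f64>();
        let w = chains
            .iter()
            .zip(&chain_means)
            .map(|(c, mu)| c.iter().map(|x| (x - mu).powi(2)).sum::<f64>() / (n - 1.0))
            .sum::<f64>()
            / m;

        let v_hat = (n - 1.0) / n * w + b / n;
        (v_hat / w).sqrt()
    }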

Effective Sample Size

The effective sample size accounts for autocorrelation in MCMC samples:

ESS = n / (1 + 2 · Σ_k ρ_k)

where ρ_k is the lag-k autocorrelation and n is the total number of samples.
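
The same idea can be sketched for ESS. This illustrative version estimates lag-k autocorrelations and truncates the sum at the first non-positive term (one common heuristic); it is not the library's effective_sample_size_mcmc:

    /// Effective sample size of a single chain (illustrative).
    fn ess_sketch(samples: &[f64]) -> f64 {
        let n = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / n;
        let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        if var == 0.0 {
            return n; // a constant chain carries no autocorrelation information
        }

        let mut rho_sum = 0.0;
        for k in 1..samples.len() {
            let cov = samples[..samples.len() - k]
                .iter()
                .zip(&samples[k..])
                .map(|(a, b)| (a - mean) * (b - mean))
                .sum::<f64>()
                / n;
            let rho = cov / var;
            if rho <= 0.0 {
                break; // truncate at the first non-positive autocorrelation
            }
            rho_sum += rho;
        }
        n / (1.0 + 2.0 * rho_sum)
    }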

    // Generate simple MCMC chains for diagnostic testing
    let mcmc_model = || {
        prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap());
            observe(addr!("y"), Normal::new(mu, 0.5).unwrap(), 1.0);
            pure(mu)
        )
    };

    // Generate two short chains for R-hat calculation
    let n_samples = 50;
    let n_warmup = 10;

    let mut chain1_samples = Vec::new();
    let mut chain2_samples = Vec::new();

    // Chain 1
    let mut rng1 = rand::rngs::StdRng::seed_from_u64(42);
    let chain1 = adaptive_mcmc_chain(&mut rng1, mcmc_model, n_samples, n_warmup);
    for (_, trace) in &chain1 {
        if let Some(mu) = trace.get_f64(&addr!("mu")) {
            chain1_samples.push(mu);
        }
    }

    // Chain 2
    let mut rng2 = rand::rngs::StdRng::seed_from_u64(123);
    let chain2 = adaptive_mcmc_chain(&mut rng2, mcmc_model, n_samples, n_warmup);
    for (_, trace) in &chain2 {
        if let Some(mu) = trace.get_f64(&addr!("mu")) {
            chain2_samples.push(mu);
        }
    }

    // Compute diagnostics
    if !chain1_samples.is_empty() && !chain2_samples.is_empty() {
        // Extract traces for R-hat calculation
        let chain1_traces: Vec<Trace> = chain1.into_iter().map(|(_, trace)| trace).collect();
        let chain2_traces: Vec<Trace> = chain2.into_iter().map(|(_, trace)| trace).collect();
        let r_hat = r_hat_f64(&[chain1_traces, chain2_traces], &addr!("mu"));
        let ess1 = effective_sample_size_mcmc(&chain1_samples);
        let ess2 = effective_sample_size_mcmc(&chain2_samples);

        println!("โœ… MCMC diagnostics computed");
        println!(
            "   - Chain 1: {} samples, ESS = {:.1}",
            chain1_samples.len(),
            ess1
        );
        println!(
            "   - Chain 2: {} samples, ESS = {:.1}",
            chain2_samples.len(),
            ess2
        );
        println!("   - R-hat: {:.4} (< 1.1 indicates convergence)", r_hat);

        if r_hat < 1.1 {
            println!("   - โœ… Chains appear to have converged");
        } else {
            println!("   - โš ๏ธ  Chains may not have converged - run longer");
        }
    }

Convergence Indicators:

  • R-hat < 1.1: Chains have converged
  • High ESS: Efficient sampling without excessive correlation
  • Multiple chains: Essential for reliable convergence assessment
  • Visual inspection: Always examine trace plots when possible

Model Structure Analysis

Model structure analysis reveals the computational graph and parameter dependencies. This analysis is crucial for understanding model complexity and identifying potential issues:

graph TD
    subgraph "Model Structure Hierarchy"
        A[Model M] --> B[Parameter Groups]
        B --> C1[Hyperpriors ฮธโ‚]
        B --> C2[Primary Parameters ฮธโ‚‚] 
        B --> C3[Observations y]
        C1 --> D1[Constraint Analysis]
        C2 --> D2[Dependency Graph]
        C3 --> D3[Likelihood Terms]
        D1 --> E[Structure Validation]
        D2 --> E
        D3 --> E
    end

Structural Invariants to validate:

  1. Address Uniqueness: every address appears exactly once (no collisions)
  2. Parameter Hierarchy: scoped addresses mirror the model's dependency structure
  3. Choice Count Consistency: Expected vs. actual parameter count
  4. Type Safety: Each address maps to consistent value types
    // Create a complex model to demonstrate structure analysis
    let complex_model = || {
        prob!(
            // Hierarchical structure
            let global_scale <- sample(addr!("global_scale"), Gamma::new(2.0, 1.0).unwrap());

            let group_params <- plate!(g in 0..3 => {
                sample(addr!("group_mean", g), Normal::new(0.0, global_scale).unwrap())
                    .bind(move |mean| {
                        sample(addr!("group_precision", g), Gamma::new(2.0, 1.0).unwrap())
                            .map(move |prec| (mean, prec))
                    })
            });

            // Individual observations (simplified to avoid move issues)
            let observations = [1.2, 1.5, 0.8];
            let likelihoods <- plate!(i in 0..observations.len() => {
                // Use fixed parameters for demonstration
                observe(addr!("obs", i), Normal::new(0.0, 1.0).unwrap(), observations[i])
            });

            pure((global_scale, group_params, likelihoods))
        )
    };

    let (_result, complex_trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        complex_model(),
    );

    // Analyze model structure
    let mut address_analysis = BTreeMap::new();
    for (addr, choice) in &complex_trace.choices {
        let addr_str = addr.0.clone();
        let category = if addr_str.contains("global") {
            "Global Parameters"
        } else if addr_str.contains("group") {
            "Group Parameters"
        } else if addr_str.contains("obs") {
            "Observations"
        } else {
            "Other"
        };

        address_analysis
            .entry(category)
            .or_insert(Vec::new())
            .push((addr_str, choice.logp));
    }

    println!("โœ… Complex model structure analysis");
    println!("   - Total choices: {}", complex_trace.choices.len());
    println!("   - Address structure:");
    for (category, addresses) in address_analysis {
        println!("     {}: {} choices", category, addresses.len());
        for (addr, logp) in addresses.iter().take(3) {
            // Show first 3
            println!("       {} (logp: {:.3})", addr, logp);
        }
        if addresses.len() > 3 {
            println!("       ... and {} more", addresses.len() - 3);
        }
    }

Structure Analysis Benefits:

  • Understand parameter organization and hierarchies
  • Detect unexpected address patterns
  • Verify choice counts match model expectations
  • Identify bottlenecks in complex models

Performance Diagnostics

Monitor computational efficiency and identify bottlenecks:

    use std::time::Instant;

    // Benchmark model execution and trace construction
    let benchmark_model = || {
        prob!(
            let params <- plate!(i in 0..100 => {
                sample(addr!("param", i), Normal::new(0.0, 1.0).unwrap())
            });
            pure(params)
        )
    };

    let start = Instant::now();
    let (_, bench_trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        benchmark_model(),
    );
    let execution_time = start.elapsed();

    // Analyze trace characteristics
    let choice_count = bench_trace.choices.len();
    let memory_estimate = choice_count * 64; // Rough estimate
    let log_weight_is_finite = bench_trace.total_log_weight().is_finite();

    println!("โœ… Performance diagnostics");
    println!("   - Execution time: {:?}", execution_time);
    println!("   - Choices created: {}", choice_count);
    println!("   - Memory estimate: ~{} bytes", memory_estimate);
    println!("   - Log-weight valid: {}", log_weight_is_finite);

    // Check for potential issues
    if choice_count == 0 {
        println!("   - โš ๏ธ  No choices recorded - possible model issue");
    }
    if !log_weight_is_finite {
        println!("   - โš ๏ธ  Invalid log-weight - check factors and observations");
    }
    if execution_time.as_millis() > 100 {
        println!("   - โš ๏ธ  Slow execution - consider optimization");
    }

Performance Warning Signs:

  • Zero choices recorded (model execution failure)
  • Infinite log-weights (constraint violations)
  • Excessive execution time (optimization needed)
  • Large memory footprint (consider streaming approaches)

Common Debugging Patterns

Systematic debugging follows a hierarchical validation strategy from basic correctness to statistical validity:

graph TD
    subgraph "Debugging Methodology"
        A[Model Implementation] --> B{Syntax Valid?}
        B -->|No| C[Fix Code Structure]
        B -->|Yes| D{Types Consistent?}
        D -->|No| E[Fix Type Errors]
        D -->|Yes| F{Finite Log-Weights?}
        F -->|No| G[Fix Constraints]
        F -->|Yes| H{Statistical Properties?}
        H -->|No| I[Validate Distributions]
        H -->|Yes| J{Convergence?}
        J -->|No| K[Tune Inference]
        J -->|Yes| L[Model Validated]
        
        C --> A
        E --> A
        G --> A
        I --> A
        K --> A
    end

Debug Level Hierarchy:

  1. Syntactic: Code compiles and types check
  2. Semantic: Model executes without runtime errors
  3. Numerical: Computations remain stable and finite
  4. Statistical: Results match theoretical expectations
  5. Convergence: Inference algorithms reach stationarity
    // Pattern 1: Systematic model testing
    fn test_model_basic_properties<F, T>(
        model_fn: F,
        expected_choice_count: usize,
        description: &str,
    ) where
        F: Fn() -> Model<T>,
    {
        let mut rng = thread_rng();
        let (_, trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            model_fn(),
        );

        println!("Testing {}", description);

        // Basic trace validity
        assert!(
            trace.total_log_weight().is_finite(),
            "Log-weight should be finite"
        );
        assert_eq!(
            trace.choices.len(),
            expected_choice_count,
            "Choice count mismatch"
        );

        // Check for common issues
        if trace.log_prior.is_infinite() {
            println!("  - โš ๏ธ  Infinite prior - check parameter ranges");
        }
        if trace.log_likelihood.is_infinite() {
            println!("  - โš ๏ธ  Infinite likelihood - check observations");
        }
        if trace.log_factors.is_infinite() {
            println!("  - โš ๏ธ  Infinite factors - check constraint satisfaction");
        }

        println!("  - โœ… {} passed basic tests", description);
    }

    // Test simple models
    test_model_basic_properties(
        || sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
        1,
        "Simple normal sampling",
    );

    test_model_basic_properties(
        || {
            prob!(
                let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
                observe(addr!("y"), Normal::new(x, 0.5).unwrap(), 1.0);
                pure(x)
            )
        },
        1,
        "Normal model with observation",
    );

    // Pattern 2: Address collision detection
    fn check_address_collisions(trace: &Trace) -> Vec<String> {
        let mut collisions = Vec::new();
        let addresses: Vec<&str> = trace.choices.keys().map(|addr| addr.0.as_str()).collect();

        for (i, addr1) in addresses.iter().enumerate() {
            for addr2 in addresses.iter().skip(i + 1) {
                if addr1 == addr2 {
                    collisions.push(format!("Duplicate address: {}", addr1));
                }
            }
        }
        collisions
    }

    let test_trace = complex_trace; // Use trace from earlier
    let collisions = check_address_collisions(&test_trace);
    if collisions.is_empty() {
        println!("  - โœ… No address collisions detected");
    } else {
        for collision in collisions {
            println!("  - โš ๏ธ  {}", collision);
        }
    }

    println!("โœ… Debugging patterns demonstration complete");

Debugging Workflow:

  1. Start Simple: Test individual components before complex composition
  2. Validate Incrementally: Add complexity one piece at a time
  3. Check Address Uniqueness: Prevent parameter collision bugs
  4. Monitor Log-Weights: Track prior, likelihood, and factor contributions
  5. Use Systematic Testing: Automated validation for all model components

Testing Framework Integration

Embed debugging checks in your test suite:

    #[test]
    fn test_trace_inspection_patterns() {
        let mut rng = thread_rng();

        let model = prob!(
            let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
            let y <- sample(addr!("y"), Beta::new(1.0, 1.0).unwrap());
            observe(addr!("obs"), Normal::new(x, 0.1).unwrap(), 1.5);
            pure((x, y))
        );

        let (_result, trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            model,
        );

        // Basic trace properties
        assert_eq!(trace.choices.len(), 2); // x and y samples
        assert!(trace.total_log_weight().is_finite());
        assert!(trace.log_likelihood.is_finite());

        // Type-safe access
        assert!(trace.get_f64(&addr!("x")).is_some());
        assert!(trace.get_f64(&addr!("y")).is_some());
        assert!(trace.get_bool(&addr!("x")).is_none()); // Type mismatch

        // Result access patterns
        assert!(trace.get_f64_result(&addr!("x")).is_ok());
        assert!(trace.get_f64_result(&addr!("missing")).is_err());
    }

    #[test]
    fn test_safe_vs_strict_handlers() {
        let mut rng = thread_rng();

        // Create base trace
        let mut base_trace = Trace::default();
        base_trace.insert_choice(addr!("param"), ChoiceValue::F64(2.5), -1.0);

        let model = sample(addr!("param"), Normal::new(0.0, 1.0).unwrap());

        // Safe replay should work
        let safe_handler = SafeReplayHandler {
            rng: &mut rng,
            base: base_trace,
            trace: Trace::default(),
            warn_on_mismatch: false,
        };
        let (result, trace) = runtime::handler::run(safe_handler, model);

        assert_eq!(result, 2.5);
        assert_eq!(trace.get_f64(&addr!("param")), Some(2.5));
    }

    #[test]
    fn test_model_structure_analysis() {
        let mut rng = thread_rng();

        let hierarchical_model = || {
            prob!(
                let global <- sample(addr!("global"), Normal::new(0.0, 1.0).unwrap());
                let locals <- plate!(i in 0..3 => {
                    sample(addr!("local", i), Normal::new(global, 0.1).unwrap())
                });
                pure((global, locals))
            )
        };

        let (_, trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            hierarchical_model(),
        );

        // Should have global + 3 local parameters
        assert_eq!(trace.choices.len(), 4);

        // Check address structure
        assert!(trace.choices.contains_key(&addr!("global")));
        assert!(trace.choices.contains_key(&addr!("local", 0)));
        assert!(trace.choices.contains_key(&addr!("local", 1)));
        assert!(trace.choices.contains_key(&addr!("local", 2)));
    }

    #[test]
    fn test_performance_diagnostics() {
        use std::time::Instant;
        let mut rng = thread_rng();

        let large_model = || {
            plate!(i in 0..50 => {
                sample(addr!("x", i), Normal::new(0.0, 1.0).unwrap())
            })
        };

        let start = Instant::now();
        let (_, trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            large_model(),
        );
        let duration = start.elapsed();

        assert_eq!(trace.choices.len(), 50);
        assert!(trace.total_log_weight().is_finite());

        // Performance should be reasonable
        assert!(duration.as_millis() < 1000, "Model execution too slow");
    }

Testing Strategy:

  • Unit tests for individual model components
  • Integration tests for complete workflows
  • Performance regression tests
  • Statistical validation against known results

Common Issues and Solutions

Issue: Infinite Log-Weights

Symptoms: trace.total_log_weight().is_infinite()

Causes:

  • Factor statements with impossible constraints
  • Parameters outside valid ranges
  • Numerical overflow in likelihood computations

Solutions:

  • Check factor conditions carefully (a diagnostic sketch follows this list)
  • Validate parameter ranges in constructors
  • Use log-space computations for numerical stability
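
A small diagnostic sketch, assuming a trace obtained from any handler run: split the total log-weight into its components and flag any individual choice whose stored log-probability is non-finite.

fn explain_infinite_weight(trace: &Trace) {
    if trace.total_log_weight().is_finite() {
        return;
    }
    // Narrow down which component went infinite
    println!(
        "prior={}, likelihood={}, factors={}",
        trace.log_prior, trace.log_likelihood, trace.log_factors
    );
    // Then find the specific choice responsible
    for (addr, choice) in &trace.choices {
        if !choice.logp.is_finite() {
            println!("non-finite log-prob at {}", addr);
        }
    }
}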

Issue: Missing or Wrong Parameter Values

Symptoms: get_*() returns None or wrong types

Causes:

  • Address typos or inconsistencies
  • Model structure doesn't match expectations
  • Type mismatches in trace replay

Solutions:

  • Use consistent address naming conventions
  • Print all addresses for verification (see the sketch below)
  • Use safe handlers for production resilience
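
A quick sketch for the address check: dump every address recorded in a trace so typos and structural mismatches are easy to spot.

fn print_trace_addresses(trace: &Trace) {
    // Collect and sort for stable, readable output
    let mut addresses: Vec<String> = trace.choices.keys().map(|a| a.to_string()).collect();
    addresses.sort();
    for a in &addresses {
        println!("trace contains: {}", a);
    }
}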

Issue: Poor MCMC Convergence

Symptoms: High R-hat values, low ESS

Causes:

  • Inappropriate step sizes
  • Poor model parameterization
  • Insufficient warm-up periods

Solutions:

  • Increase warm-up iterations
  • Reparameterize for better geometry
  • Use adaptive algorithms with proper tuning

Issue: Slow Model Execution

Symptoms: High execution times, memory usage

Causes:

  • Inefficient model structure
  • Excessive address creation
  • Large trace construction overhead

Solutions:

  • Use plate! for vectorized operations
  • Pre-allocate data structures when possible
  • Profile with performance diagnostics

Best Practices Summary

  1. Debug Incrementally: Start simple and add complexity systematically
  2. Use All Tools: Combine trace inspection, validation, and diagnostics
  3. Test Edge Cases: Verify behavior at parameter boundaries
  4. Monitor Performance: Track execution time and memory usage
  5. Validate Statistically: Compare against known theoretical results
  6. Handle Errors Gracefully: Use safe handlers in production
  7. Document Assumptions: Clear model specifications aid debugging

Effective debugging transforms probabilistic programming from guesswork into systematic model development. Fugue's comprehensive debugging toolkit enables confident deployment of complex probabilistic systems.

Custom Handlers

Fugue's handler system is grounded in algebraic effect theory, providing a principled approach to effect interpretation and computational extension. Custom handlers enable specialized execution strategies, monitoring systems, and novel inference algorithms through systematic effect handling and handler composition.

Algebraic Effects Foundation

Fugue models effects through an algebra where:

  • E is the set of effect operations (sample, observe, factor)
  • Σ is the signature defining each operation's argument and return types
  • Handlers provide interpretations of these operations into a carrier type C

This algebraic structure ensures compositional semantics and modular interpretation.

Understanding the Handler Trait

The Handler trait provides the algebraic signature for probabilistic effects. Each method represents an effect operation with its semantic interpretation:

graph TD
    subgraph "Handler Architecture"
        A[Effect E] --> B{Effect Type}
        B -->|sample| C[on_sample_T]
        B -->|observe| D[on_observe_T]  
        B -->|factor| E[on_factor]
        C --> F[Handler State H]
        D --> F
        E --> F
        F --> G[Updated State H']
        G --> H[Continue Execution]
    end

Effect Algebra: Each handler supplies an interpretation of the probabilistic effect signature, mapping the operations in Σ into a carrier type C, where the carrier type varies by handler implementation.

/// Simple handler that just samples from priors (similar to PriorHandler)
struct BasicHandler<R: Rng> {
    rng: R,
    trace: Trace,
}

impl<R: Rng> Handler for BasicHandler<R> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        let value = dist.sample(&mut self.rng);
        let log_prob = dist.log_prob(&value);

        // Store in trace
        self.trace.log_prior += log_prob;
        self.trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::F64(value),
                logp: log_prob,
            },
        );

        value
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        let value = dist.sample(&mut self.rng);
        let log_prob = dist.log_prob(&value);

        self.trace.log_prior += log_prob;
        self.trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::Bool(value),
                logp: log_prob,
            },
        );

        value
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        let value = dist.sample(&mut self.rng);
        let log_prob = dist.log_prob(&value);

        self.trace.log_prior += log_prob;
        self.trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::U64(value),
                logp: log_prob,
            },
        );

        value
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        let value = dist.sample(&mut self.rng);
        let log_prob = dist.log_prob(&value);

        self.trace.log_prior += log_prob;
        self.trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::Usize(value),
                logp: log_prob,
            },
        );

        value
    }

    fn on_observe_f64(&mut self, _addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        self.trace.log_likelihood += dist.log_prob(&value);
    }

    fn on_observe_bool(&mut self, _addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        self.trace.log_likelihood += dist.log_prob(&value);
    }

    fn on_observe_u64(&mut self, _addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        self.trace.log_likelihood += dist.log_prob(&value);
    }

    fn on_observe_usize(&mut self, _addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        self.trace.log_likelihood += dist.log_prob(&value);
    }

    fn on_factor(&mut self, log_weight: f64) {
        self.trace.log_factors += log_weight;
    }

    fn finish(self) -> Trace {
        self.trace
    }
}

Handler Responsibilities:

  • Type-specific sampling: Handle f64, bool, u64, and usize distributions appropriately
  • Observation handling: Process observed values and update likelihood components
  • Factor management: Accumulate constraint and penalty terms
  • Trace construction: Build execution traces with choices and log-weights
  • Resource cleanup: Properly finalize and return traces

Decorator Pattern for Handler Composition

The decorator pattern implements handler composition through effect forwarding with computational augmentation, following the mathematical principle of function composition: (g ∘ f)(x) = g(f(x)).

Applied to handlers, a decorator D wraps an inner handler H to form D ∘ H, forwarding each effect to H while layering its own behavior on top:

graph LR
    subgraph "Handler Composition Chain"
        A[Effect] --> B[Decoratorโ‚]
        B --> C[Decoratorโ‚‚]
        C --> D[Base Handler]
        D --> E[Result]

        B -.->|"Log, Monitor"| F[Side Effects]
        C -.->|"Transform, Filter"| G[Modifications]
    end

Compositional Properties:

  • Associativity: (D₁ ∘ D₂) ∘ D₃ = D₁ ∘ (D₂ ∘ D₃)
  • Identity: wrapping with a no-op decorator leaves the handler unchanged (id ∘ H = H)
  • Effect Preservation: Core semantics remain unchanged

/// Handler decorator that logs all operations
struct LoggingHandler<H: Handler> {
    inner: H,
    log: Vec<String>,
    verbose: bool,
}

impl<H: Handler> LoggingHandler<H> {
    fn new(inner: H, verbose: bool) -> Self {
        Self {
            inner,
            log: Vec::new(),
            verbose,
        }
    }

    fn log_operation(&mut self, operation: String) {
        if self.verbose {
            println!("LOG: {}", operation);
        }
        self.log.push(operation);
    }
}

impl<H: Handler> Handler for LoggingHandler<H> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        let value = self.inner.on_sample_f64(addr, dist);
        self.log_operation(format!("Sample f64 at {}: {:.3}", addr, value));
        value
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        let value = self.inner.on_sample_bool(addr, dist);
        self.log_operation(format!("Sample bool at {}: {}", addr, value));
        value
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        let value = self.inner.on_sample_u64(addr, dist);
        self.log_operation(format!("Sample u64 at {}: {}", addr, value));
        value
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        let value = self.inner.on_sample_usize(addr, dist);
        self.log_operation(format!("Sample usize at {}: {}", addr, value));
        value
    }

    fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        self.log_operation(format!("Observe f64 at {}: {:.3}", addr, value));
        self.inner.on_observe_f64(addr, dist, value);
    }

    fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        self.log_operation(format!("Observe bool at {}: {}", addr, value));
        self.inner.on_observe_bool(addr, dist, value);
    }

    fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        self.log_operation(format!("Observe u64 at {}: {}", addr, value));
        self.inner.on_observe_u64(addr, dist, value);
    }

    fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        self.log_operation(format!("Observe usize at {}: {}", addr, value));
        self.inner.on_observe_usize(addr, dist, value);
    }

    fn on_factor(&mut self, log_weight: f64) {
        self.log_operation(format!("Factor: {:.3}", log_weight));
        self.inner.on_factor(log_weight);
    }

    fn finish(self) -> Trace {
        let trace = self.inner.finish();
        println!("โœ… Logged {} operations total", self.log.len());
        trace
    }
}

Decorator Benefits:

  • Non-invasive functionality addition
  • Composable and reusable components
  • Separation of concerns between core logic and cross-cutting features
  • Easy to enable/disable features dynamically

Stateful Handlers for Analytics

Handlers can maintain state to accumulate statistics and monitor model behavior:

/// Handler that accumulates statistics about model execution
#[derive(Debug)]
struct ExecutionStats {
    sample_counts: HashMap<String, u32>, // Type -> count
    observe_counts: HashMap<String, u32>,
    factor_count: u32,
    total_log_weight: f64,
    parameter_ranges: HashMap<String, (f64, f64)>, // Address -> (min, max) for f64 params
}

impl Default for ExecutionStats {
    fn default() -> Self {
        Self {
            sample_counts: HashMap::new(),
            observe_counts: HashMap::new(),
            factor_count: 0,
            total_log_weight: 0.0,
            parameter_ranges: HashMap::new(),
        }
    }
}

struct StatisticsHandler<H: Handler> {
    inner: H,
    stats: ExecutionStats,
}

impl<H: Handler> StatisticsHandler<H> {
    fn new(inner: H) -> Self {
        Self {
            inner,
            stats: ExecutionStats::default(),
        }
    }

    fn update_f64_range(&mut self, addr: &Address, value: f64) {
        let key = addr.0.clone();
        self.stats
            .parameter_ranges
            .entry(key)
            .and_modify(|(min, max)| {
                *min = min.min(value);
                *max = max.max(value);
            })
            .or_insert((value, value));
    }
}

impl<H: Handler> Handler for StatisticsHandler<H> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        let value = self.inner.on_sample_f64(addr, dist);
        *self
            .stats
            .sample_counts
            .entry("f64".to_string())
            .or_insert(0) += 1;
        self.update_f64_range(addr, value);
        value
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        let value = self.inner.on_sample_bool(addr, dist);
        *self
            .stats
            .sample_counts
            .entry("bool".to_string())
            .or_insert(0) += 1;
        value
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        let value = self.inner.on_sample_u64(addr, dist);
        *self
            .stats
            .sample_counts
            .entry("u64".to_string())
            .or_insert(0) += 1;
        value
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        let value = self.inner.on_sample_usize(addr, dist);
        *self
            .stats
            .sample_counts
            .entry("usize".to_string())
            .or_insert(0) += 1;
        value
    }

    fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        *self
            .stats
            .observe_counts
            .entry("f64".to_string())
            .or_insert(0) += 1;
        self.inner.on_observe_f64(addr, dist, value);
    }

    fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        *self
            .stats
            .observe_counts
            .entry("bool".to_string())
            .or_insert(0) += 1;
        self.inner.on_observe_bool(addr, dist, value);
    }

    fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        *self
            .stats
            .observe_counts
            .entry("u64".to_string())
            .or_insert(0) += 1;
        self.inner.on_observe_u64(addr, dist, value);
    }

    fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        *self
            .stats
            .observe_counts
            .entry("usize".to_string())
            .or_insert(0) += 1;
        self.inner.on_observe_usize(addr, dist, value);
    }

    fn on_factor(&mut self, log_weight: f64) {
        self.stats.factor_count += 1;
        self.stats.total_log_weight += log_weight;
        self.inner.on_factor(log_weight);
    }

    fn finish(self) -> Trace {
        println!("โœ… Execution Statistics:");
        println!("   - Samples by type: {:?}", self.stats.sample_counts);
        println!("   - Observations by type: {:?}", self.stats.observe_counts);
        println!("   - Factor operations: {}", self.stats.factor_count);
        println!("   - Parameter ranges:");
        for (addr, (min, max)) in &self.stats.parameter_ranges {
            println!("     {}: [{:.3}, {:.3}]", addr, min, max);
        }
        self.inner.finish()
    }
}

Analytics Applications:

  • Model complexity analysis (parameter counts by type)
  • Execution profiling and bottleneck identification
  • Parameter range monitoring for numerical stability
  • Distribution usage patterns for optimization

Conditional and Filtering Handlers

Implement business logic and constraints through conditional handling:

/// Handler that filters/modifies values based on conditions
struct FilteringHandler<H: Handler> {
    inner: H,
    f64_clamp_range: Option<(f64, f64)>,
    bool_flip_probability: f64,
    rng: rand::rngs::ThreadRng,
}

impl<H: Handler> FilteringHandler<H> {
    fn new(inner: H, f64_clamp_range: Option<(f64, f64)>, bool_flip_probability: f64) -> Self {
        Self {
            inner,
            f64_clamp_range,
            bool_flip_probability,
            rng: thread_rng(),
        }
    }
}

impl<H: Handler> Handler for FilteringHandler<H> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        let mut value = self.inner.on_sample_f64(addr, dist);

        // Apply clamping if specified
        if let Some((min, max)) = self.f64_clamp_range {
            value = value.clamp(min, max);
        }

        value
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        let mut value = self.inner.on_sample_bool(addr, dist);

        // Flip boolean with specified probability
        if self.rng.gen::<f64>() < self.bool_flip_probability {
            value = !value;
        }

        value
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        self.inner.on_sample_u64(addr, dist)
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        self.inner.on_sample_usize(addr, dist)
    }

    fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        self.inner.on_observe_f64(addr, dist, value);
    }

    fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        self.inner.on_observe_bool(addr, dist, value);
    }

    fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        self.inner.on_observe_u64(addr, dist, value);
    }

    fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        self.inner.on_observe_usize(addr, dist, value);
    }

    fn on_factor(&mut self, log_weight: f64) {
        self.inner.on_factor(log_weight);
    }

    fn finish(self) -> Trace {
        self.inner.finish()
    }
}

Filtering Use Cases:

  • Parameter clamping for numerical stability
  • Outlier detection and handling
  • Domain-specific constraints enforcement
  • Robustness testing through perturbations

Performance Monitoring

Track and optimize computational characteristics with monitoring handlers:

use std::time::{Duration, Instant};

/// Handler that monitors performance characteristics
struct PerformanceHandler<H: Handler> {
    inner: H,
    start_time: Instant,
    operation_times: Vec<Duration>,
    sample_count: u32,
    observe_count: u32,
}

impl<H: Handler> PerformanceHandler<H> {
    fn new(inner: H) -> Self {
        Self {
            inner,
            start_time: Instant::now(),
            operation_times: Vec::new(),
            sample_count: 0,
            observe_count: 0,
        }
    }

    fn time_operation<F, R>(&mut self, operation: F) -> R
    where
        F: FnOnce(&mut H) -> R,
    {
        let start = Instant::now();
        let result = operation(&mut self.inner);
        let duration = start.elapsed();
        self.operation_times.push(duration);
        result
    }
}

impl<H: Handler> Handler for PerformanceHandler<H> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        self.sample_count += 1;
        self.time_operation(|inner| inner.on_sample_f64(addr, dist))
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        self.sample_count += 1;
        self.time_operation(|inner| inner.on_sample_bool(addr, dist))
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        self.sample_count += 1;
        self.time_operation(|inner| inner.on_sample_u64(addr, dist))
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        self.sample_count += 1;
        self.time_operation(|inner| inner.on_sample_usize(addr, dist))
    }

    fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        self.observe_count += 1;
        self.time_operation(|inner| inner.on_observe_f64(addr, dist, value))
    }

    fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        self.observe_count += 1;
        self.time_operation(|inner| inner.on_observe_bool(addr, dist, value))
    }

    fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        self.observe_count += 1;
        self.time_operation(|inner| inner.on_observe_u64(addr, dist, value))
    }

    fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        self.observe_count += 1;
        self.time_operation(|inner| inner.on_observe_usize(addr, dist, value))
    }

    fn on_factor(&mut self, log_weight: f64) {
        self.time_operation(|inner| inner.on_factor(log_weight))
    }

    fn finish(self) -> Trace {
        let total_time = self.start_time.elapsed();
        let avg_op_time = if !self.operation_times.is_empty() {
            self.operation_times.iter().sum::<Duration>() / self.operation_times.len() as u32
        } else {
            Duration::ZERO
        };

        println!("โœ… Performance Monitoring Results:");
        println!("   - Total execution time: {:?}", total_time);
        println!("   - Operations performed: {}", self.operation_times.len());
        println!("   - Sample operations: {}", self.sample_count);
        println!("   - Observe operations: {}", self.observe_count);
        println!("   - Average operation time: {:?}", avg_op_time);

        self.inner.finish()
    }
}

Performance Insights:

  • Operation timing and bottleneck identification
  • Memory allocation patterns
  • Execution hotspots and optimization opportunities
  • Scalability analysis for production deployment

Custom Inference Algorithms

Custom inference algorithms extend Fugue's effect interpretation to implement novel sampling strategies and approximate inference methods. Each algorithm provides a unique semantic mapping from probabilistic effects to computational actions:

graph TD
    subgraph "Inference Algorithm Architecture"
        A[Model M] --> B[Effect Sequence]
        B --> C{Handler Type}
        C -->|MCMC| D[Markov Chain<br/>Sampling]
        C -->|VI| E[Variational<br/>Approximation]
        C -->|SMC| F[Sequential<br/>Monte Carlo]
        C -->|ABC| G[Approximate<br/>Bayesian Computation]

        D --> H[Posterior Samples]
        E --> I[Approximate<br/>Distribution]
        F --> J[Weighted<br/>Particles]
        G --> K[Likelihood-Free<br/>Samples]
    end

Algorithm Design Principles:

  1. Effect Consistency: The handler's interpretation of sample, observe, and factor preserves probabilistic semantics
  2. Convergence Guarantees: Algorithm converges to target distribution under regularity conditions
  3. Computational Tractability: Runtime complexity is polynomial in problem dimensions
  4. Statistical Efficiency: Effective sample size scales appropriately with computational cost

Mathematical Framework: Each inference handler implements a stochastic operator T whose fixed point is the target posterior π, i.e. T(π) = π.

/// Simple custom MCMC-like handler that perturbs existing values
struct SimpleMCMCHandler<R: Rng> {
    rng: R,
    base_trace: Trace,
    current_trace: Trace,
    perturbation_scale: f64,
}

impl<R: Rng> SimpleMCMCHandler<R> {
    fn new(rng: R, base_trace: Trace, perturbation_scale: f64) -> Self {
        Self {
            rng,
            base_trace,
            current_trace: Trace::default(),
            perturbation_scale,
        }
    }
}

impl<R: Rng> Handler for SimpleMCMCHandler<R> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        let value = if let Some(base_value) = self.base_trace.get_f64(addr) {
            // Perturb existing value
            let perturbation = Normal::new(0.0, self.perturbation_scale).unwrap();
            base_value + perturbation.sample(&mut self.rng)
        } else {
            // Sample fresh if not in base trace
            dist.sample(&mut self.rng)
        };

        let log_prob = dist.log_prob(&value);
        self.current_trace.log_prior += log_prob;
        self.current_trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::F64(value),
                logp: log_prob,
            },
        );

        value
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        let value = if let Some(base_value) = self.base_trace.get_bool(addr) {
            // Maybe flip the boolean with small probability
            if self.rng.gen::<f64>() < 0.1 {
                !base_value
            } else {
                base_value
            }
        } else {
            dist.sample(&mut self.rng)
        };

        let log_prob = dist.log_prob(&value);
        self.current_trace.log_prior += log_prob;
        self.current_trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::Bool(value),
                logp: log_prob,
            },
        );

        value
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        // For simplicity, just use base value or sample fresh
        let value = self
            .base_trace
            .get_u64(addr)
            .unwrap_or_else(|| dist.sample(&mut self.rng));

        let log_prob = dist.log_prob(&value);
        self.current_trace.log_prior += log_prob;
        self.current_trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::U64(value),
                logp: log_prob,
            },
        );

        value
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        let value = self
            .base_trace
            .get_usize(addr)
            .unwrap_or_else(|| dist.sample(&mut self.rng));

        let log_prob = dist.log_prob(&value);
        self.current_trace.log_prior += log_prob;
        self.current_trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::Usize(value),
                logp: log_prob,
            },
        );

        value
    }

    fn on_observe_f64(&mut self, _addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        self.current_trace.log_likelihood += dist.log_prob(&value);
    }

    fn on_observe_bool(&mut self, _addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        self.current_trace.log_likelihood += dist.log_prob(&value);
    }

    fn on_observe_u64(&mut self, _addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        self.current_trace.log_likelihood += dist.log_prob(&value);
    }

    fn on_observe_usize(&mut self, _addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        self.current_trace.log_likelihood += dist.log_prob(&value);
    }

    fn on_factor(&mut self, log_weight: f64) {
        self.current_trace.log_factors += log_weight;
    }

    fn finish(self) -> Trace {
        self.current_trace
    }
}

Inference Handler Patterns:

  • MCMC variants: Custom proposal mechanisms and acceptance criteria (an acceptance-step sketch follows this list)
  • Variational methods: Gradient-based optimization with custom families
  • Rejection sampling: Domain-specific acceptance/rejection logic
  • Importance sampling: Custom proposal distributions and weight calculations
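
The SimpleMCMCHandler above only builds a proposal trace; a complete Metropolis step also needs an acceptance decision. A minimal sketch, assuming a symmetric random-walk proposal so the proposal densities cancel (accept_proposal is an illustrative helper, not part of Fugue's API):

/// Accept the proposed trace with probability min(1, exp(Δ log-weight)).
fn accept_proposal<R: Rng>(rng: &mut R, current: &Trace, proposed: &Trace) -> bool {
    let log_alpha = proposed.total_log_weight() - current.total_log_weight();
    log_alpha >= 0.0 || rng.gen::<f64>() < log_alpha.exp()
}

// Usage sketch: keep the proposed trace if accepted, otherwise stay put.
// let next_trace = if accept_proposal(&mut rng, &base_trace, &mcmc_trace) {
//     mcmc_trace
// } else {
//     base_trace
// };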

Handler Composition and Chaining

Handler chaining implements multi-stage effect processing through systematic composition operators. The composition forms a computational pipeline with well-defined data flow and effect propagation:

graph TD
    subgraph "Handler Composition Pipeline"
        A[Raw Effect E] --> B[Statistics Handler]
        B --> C[Logging Handler]
        C --> D[Performance Handler]
        D --> E[Base Handler]
        E --> F[Result + Trace]

        B -.->|Metrics| G[(Statistics DB)]
        C -.->|Events| H[(Log Stream)]
        D -.->|Timing| I[(Performance Monitor)]

        F --> J{Validation}
        J -->|Pass| K[Success]
        J -->|Fail| L[Error Recovery]
    end

Composition Laws:

  1. Preservation: Wrapping a handler in a decorator preserves effect semantics
  2. Associativity: Composition order affects performance but not correctness
  3. Commutativity: Decorators with disjoint side effects commute
  4. Distributivity: Decorators distribute over unions of effects, so wrapping a combined handler is equivalent to wrapping each part

Performance Analysis: A handler chain of depth n adds O(n · c) overhead per effect, where c is the per-handler dispatch cost. Optimization strategies include handler fusion and effect batching.

fn main() {
    println!("=== Custom Handlers in Fugue ===\n");

    println!("1. Basic Custom Handler Implementation");
    println!("------------------------------------");

    // Test the basic handler
    let mut rng = thread_rng();
    let handler = BasicHandler {
        rng: &mut rng,
        trace: Trace::default(),
    };

    let test_model = || sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
    let (result, trace) = runtime::handler::run(handler, test_model());

    println!("โœ… Basic handler executed");
    println!("   - Result: {:.3}", result);
    println!("   - Trace choices: {}", trace.choices.len());
    println!("   - Total log-weight: {:.3}", trace.total_log_weight());
    println!();

    println!("2. Logging Handler - Decorator Pattern");
    println!("-------------------------------------");

    // Test the logging handler
    let mut rng = thread_rng();
    let base_handler = PriorHandler {
        rng: &mut rng,
        trace: Trace::default(),
    };
    let logging_handler = LoggingHandler::new(base_handler, false); // Non-verbose

    let logged_model = || {
        prob!(
            let x <- sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap());
            observe(addr!("obs"), Normal::new(x, 0.5).unwrap(), 1.2);
            factor(-0.5);
            pure(x)
        )
    };

    let (result, _trace) = runtime::handler::run(logging_handler, logged_model());
    println!("   - Logged execution result: {:.3}", result);
    println!();

    println!("3. Statistics Accumulating Handler");
    println!("--------------------------------");

    // Test the statistics handler
    let mut rng = thread_rng();
    let base_handler = PriorHandler {
        rng: &mut rng,
        trace: Trace::default(),
    };
    let stats_handler = StatisticsHandler::new(base_handler);

    let complex_model = || {
        prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap());
            let is_outlier <- sample(addr!("outlier"), Bernoulli::new(0.1).unwrap());
            let count <- sample(addr!("count"), Poisson::new(3.0).unwrap());
            let category <- sample(addr!("category"), Categorical::new(vec![0.3, 0.4, 0.3]).unwrap());

            observe(addr!("y1"), Normal::new(mu, 1.0).unwrap(), 1.5);
            observe(addr!("y2"), Normal::new(mu, 1.0).unwrap(), 2.1);
            factor(if is_outlier { -2.0 } else { 0.0 });

            pure((mu, is_outlier, count, category))
        )
    };

    let (result, _trace) = runtime::handler::run(stats_handler, complex_model());
    println!(
        "   - Complex model result: {:?}",
        (result.0.round(), result.1, result.2, result.3)
    );
    println!();

    println!("4. Conditional Filtering Handler");
    println!("-------------------------------");

    // Test the filtering handler
    let mut rng = thread_rng();
    let base_handler = PriorHandler {
        rng: &mut rng,
        trace: Trace::default(),
    };
    let filtering_handler = FilteringHandler::new(
        base_handler,
        Some((-2.0, 2.0)), // Clamp f64 values to [-2, 2]
        0.1,               // 10% chance to flip booleans
    );

    let filter_test_model = || {
        prob!(
            let x <- sample(addr!("x"), Normal::new(0.0, 5.0).unwrap()); // Wide distribution
            let flag <- sample(addr!("flag"), Bernoulli::new(0.8).unwrap());
            pure((x, flag))
        )
    };

    let (result, _trace) = runtime::handler::run(filtering_handler, filter_test_model());
    println!("โœ… Filtering handler executed");
    println!("   - Clamped value: {:.3} (should be in [-2, 2])", result.0);
    println!(
        "   - Boolean value: {} (may be flipped from original)",
        result.1
    );
    println!();

    println!("5. Performance Monitoring Handler");
    println!("--------------------------------");

    // Test the performance handler
    let mut rng = thread_rng();
    let base_handler = PriorHandler {
        rng: &mut rng,
        trace: Trace::default(),
    };
    let perf_handler = PerformanceHandler::new(base_handler);

    let perf_test_model = || {
        plate!(i in 0..10 => {
            sample(addr!("param", i), Normal::new(0.0, 1.0).unwrap())
        })
    };

    let (_result, _trace) = runtime::handler::run(perf_handler, perf_test_model());
    println!();

    println!("6. Custom Inference Handler");
    println!("---------------------------");

    // Test the custom inference handler
    let mut rng1 = thread_rng();
    let rng2 = thread_rng();

    // First get a base trace
    let base_handler = PriorHandler {
        rng: &mut rng1,
        trace: Trace::default(),
    };

    let inference_model = || {
        prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap());
            observe(addr!("y"), Normal::new(mu, 0.5).unwrap(), 1.0);
            pure(mu)
        )
    };

    let (base_result, base_trace) = runtime::handler::run(base_handler, inference_model());

    // Now use custom MCMC handler to perturb it
    let base_log_weight = base_trace.total_log_weight();
    let mcmc_handler = SimpleMCMCHandler::new(rng2, base_trace, 0.1);
    let (mcmc_result, mcmc_trace) = runtime::handler::run(mcmc_handler, inference_model());

    println!("โœ… Custom MCMC-like inference:");
    println!("   - Base result: {:.3}", base_result);
    println!("   - MCMC result: {:.3}", mcmc_result);
    println!("   - Base log-weight: {:.3}", base_log_weight);
    println!("   - MCMC log-weight: {:.3}", mcmc_trace.total_log_weight());
    println!();

    println!("7. Handler Composition and Chaining");
    println!("----------------------------------");

    // Demonstrate composing multiple handler decorators
    let mut rng = thread_rng();
    let base_handler = PriorHandler {
        rng: &mut rng,
        trace: Trace::default(),
    };

    // Chain multiple decorators: Statistics -> Logging -> Performance -> Base
    let stats_handler = StatisticsHandler::new(base_handler);
    let logging_handler = LoggingHandler::new(stats_handler, false);
    let performance_handler = PerformanceHandler::new(logging_handler);

    let composition_model = || {
        prob!(
            let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
            let y <- sample(addr!("y"), Bernoulli::new(0.7).unwrap());
            observe(addr!("obs"), Normal::new(x, 0.2).unwrap(), 0.5);
            factor(-0.3);
            pure((x, y))
        )
    };

    println!("โœ… Handler composition example:");
    let (_result, _trace) = runtime::handler::run(performance_handler, composition_model());
    println!("   - Multiple handler layers executed successfully");
    println!();

    println!("=== Custom Handler Patterns Demonstrated! ===");
}

Composition Strategies:

  • Layered approach: Statistics โ†’ Logging โ†’ Performance โ†’ Base
  • Conditional activation: Enable decorators based on environment/configuration (sketched below)
  • Feature flags: Runtime selection of handler combinations
  • Pipeline optimization: Order decorators for minimal overhead
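
A minimal sketch of conditional activation: the environment variable name (FUGUE_VERBOSE) is illustrative, and the same pattern extends to selecting whole decorator stacks per environment.

// Turn verbose logging on only when FUGUE_VERBOSE is set, so the same
// code path runs in development and production.
let verbose = std::env::var("FUGUE_VERBOSE").is_ok();
let mut rng = thread_rng();
let base = PriorHandler {
    rng: &mut rng,
    trace: Trace::default(),
};
let handler = LoggingHandler::new(base, verbose);
let (_result, _trace) = runtime::handler::run(
    handler,
    sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
);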

Advanced Handler Patterns

Caching Handler

struct CachingHandler<H: Handler> {
    inner: H,
    cache: HashMap<(Address, String), ChoiceValue>, // Address + dist info -> cached value
}
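
A sketch of the lookup path on the f64 channel; the cache key here is deliberately simplified (address plus a type tag), whereas real code would also fingerprint the distribution's parameters so a changed prior invalidates the cached draw. The remaining Handler methods would forward to the inner handler unchanged.

impl<H: Handler> Handler for CachingHandler<H> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        let key = (addr.clone(), "f64".to_string());
        if let Some(ChoiceValue::F64(v)) = self.cache.get(&key) {
            return *v; // Reuse the previously sampled value
        }
        let value = self.inner.on_sample_f64(addr, dist);
        self.cache.insert(key, ChoiceValue::F64(value));
        value
    }
    // ... other on_sample_* / on_observe_* / on_factor / finish forward to self.inner
}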

Distributed Handler

struct DistributedHandler<H: Handler> {
    inner: H,
    worker_id: usize,
    coordinator: Arc<Mutex<SharedState>>,
}

Fault-Tolerant Handler

struct FaultTolerantHandler<H: Handler> {
    inner: H,
    fallback_strategy: FallbackMode,
    error_count: u32,
    max_errors: u32,
}

Testing Custom Handlers

Systematic testing ensures handler correctness:

    #[test]
    fn test_basic_custom_handler() {
        let mut rng = thread_rng();
        let handler = BasicHandler {
            rng: &mut rng,
            trace: Trace::default(),
        };

        let model = sample(addr!("test"), Normal::new(0.0, 1.0).unwrap());
        let (result, trace) = runtime::handler::run(handler, model);

        assert!(trace.choices.contains_key(&addr!("test")));
        assert!(trace.total_log_weight().is_finite());
        assert!(result.is_finite());
    }

    #[test]
    fn test_logging_handler() {
        let mut rng = thread_rng();
        let base_handler = PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        };
        let logging_handler = LoggingHandler::new(base_handler, false);

        let model = prob!(
            let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
            observe(addr!("obs"), Normal::new(x, 0.1).unwrap(), 1.0);
            pure(x)
        );

        let (result, trace) = runtime::handler::run(logging_handler, model);

        assert!(trace.choices.contains_key(&addr!("x")));
        assert!(trace.log_likelihood.is_finite());
        assert!(result.is_finite());
    }

    #[test]
    fn test_statistics_handler() {
        let mut rng = thread_rng();
        let base_handler = PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        };
        let stats_handler = StatisticsHandler::new(base_handler);

        let model = prob!(
            let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
            let flag <- sample(addr!("flag"), Bernoulli::new(0.5).unwrap());
            pure((x, flag))
        );

        let (result, trace) = runtime::handler::run(stats_handler, model);

        assert_eq!(trace.choices.len(), 2);
        assert!(result.0.is_finite());
    }

    #[test]
    fn test_handler_composition() {
        let mut rng = thread_rng();
        let base_handler = PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        };

        // Compose multiple handlers
        let logged_handler = LoggingHandler::new(base_handler, false);
        let stats_handler = StatisticsHandler::new(logged_handler);

        let model = prob!(
            let x <- sample(addr!("param"), Normal::new(0.0, 1.0).unwrap());
            factor(-0.5);
            pure(x)
        );

        let (result, trace) = runtime::handler::run(stats_handler, model);

        assert!(trace.choices.contains_key(&addr!("param")));
        assert!(trace.log_factors.abs() > 0.0); // Factor was applied
        assert!(result.is_finite());
    }

Testing Strategy:

  • Unit tests: Individual handler method behavior
  • Integration tests: Handler with realistic models
  • Property tests: Invariant verification across random inputs (sketched below)
  • Composition tests: Multi-layer handler combinations
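
A property-style sketch: repeat random executions of a composed handler stack and assert the invariants that composition must preserve (finite results and weights, expected addresses present).

    #[test]
    fn test_composition_preserves_invariants() {
        for _ in 0..20 {
            let mut rng = thread_rng();
            let base = PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            };
            let composed = StatisticsHandler::new(LoggingHandler::new(base, false));

            let model = prob!(
                let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
                observe(addr!("obs"), Normal::new(x, 0.5).unwrap(), 0.0);
                pure(x)
            );

            let (result, trace) = runtime::handler::run(composed, model);
            assert!(result.is_finite());
            assert!(trace.choices.contains_key(&addr!("x")));
            assert!(trace.total_log_weight().is_finite());
        }
    }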

Production Considerations

Error Handling
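
The fragment below is a sketch of the validation pattern: ProductionHandler is assumed to wrap an inner handler and expose a log_error helper, and the remaining Handler methods would follow the same check-and-fall-back shape.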

impl Handler for ProductionHandler {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        match self.inner.on_sample_f64(addr, dist) {
            value if value.is_finite() => value,
            _ => {
                self.log_error(addr, "Non-finite sample");
                0.0 // Safe fallback
            }
        }
    }
}

Memory Management

struct MemoryEfficientHandler<H: Handler> {
    inner: H,
    choice_pool: Vec<Choice>, // Reusable allocations
    max_trace_size: usize,
}

Monitoring Integration

struct MetricsHandler<H: Handler> {
    inner: H,
    metrics_client: MetricsClient,
    model_name: String,
}

Common Patterns Summary

  1. Decorator Pattern: Wrap handlers for additional functionality
  2. State Accumulation: Track statistics and model behavior
  3. Conditional Logic: Apply domain-specific constraints
  4. Performance Monitoring: Identify bottlenecks and optimization opportunities
  5. Custom Inference: Implement specialized algorithms
  6. Composition: Chain multiple handlers for comprehensive capabilities
  7. Error Handling: Graceful degradation and recovery
  8. Resource Management: Efficient memory and computation usage

Best Practices

  1. Single Responsibility: Each handler should have one clear purpose
  2. Composability: Design handlers to work well in combination
  3. Type Safety: Leverage Rust's type system for correctness
  4. Performance: Minimize overhead in hot paths
  5. Error Handling: Fail gracefully with meaningful diagnostics
  6. Testing: Comprehensive unit and integration tests
  7. Documentation: Clear API contracts and usage examples

Custom handlers transform Fugue from a probabilistic programming framework into a platform for building specialized inference systems, analytics tools, and production-ready probabilistic applications.

Production Deployment

Production deployment of probabilistic models requires reliability engineering, performance optimization, and operational excellence at scale. This guide establishes a mathematical framework for fault tolerance, service reliability, and system observability using Fugue's production-ready infrastructure patterns.

Reliability Theory Framework

Production systems exhibit stochastic reliability characterized by:

  • Availability: A = MTBF / (MTBF + MTTR), where MTBF = Mean Time Between Failures and MTTR = Mean Time To Repair
  • Reliability Function: R(t) = e^(-λt) for exponential failure rates with failure rate λ
  • Service Level Agreement: the fraction of successful requests over a reporting window must meet an agreed target (e.g., 99.9%)

Fugue's deployment patterns optimize these metrics through systematic fault isolation and graceful degradation.

Error Handling and Graceful Degradation

Graceful degradation implements fault tolerance through systematic error recovery and service continuity. The mathematical foundation relies on Markov reliability models and circuit breaker theory:

stateDiagram-v2
    [*] --> Healthy
    Healthy --> Degraded : Error Rate > ฯ„โ‚
    Degraded --> Failed : Error Rate > ฯ„โ‚‚  
    Failed --> Recovery : Time > T_recovery
    Recovery --> Healthy : Success Rate > ฯƒ
    Degraded --> Healthy : Error Rate < ฯ„โ‚€
    
    note right of Healthy
        Error Rate: ฮป < ฯ„โ‚€
        SLA: 99.9%
        Full Functionality
    end note
    
    note right of Degraded  
        Error Rate: ฯ„โ‚€ < ฮป < ฯ„โ‚
        SLA: 95%
        Limited Functionality
    end note
    
    note right of Failed
        Error Rate: ฮป > ฯ„โ‚‚
        Circuit Open
        Fallback Mode
    end note

Circuit Breaker Mathematics: Failures are modeled as a Poisson process with rate λ; the breaker transitions between the states above when the observed error rate crosses the thresholds τ₀ < τ₁ < τ₂.

Error Budget Model: For an SLA target S, the error budget is the remaining fraction 1 - S of the reporting window; at S = 99.9% over 30 days that is roughly 43 minutes of allowable downtime.

/// Production-ready handler that gracefully handles failures
struct RobustProductionHandler<H: Handler> {
    inner: H,
    error_count: u32,
    max_errors: u32,
    _fallback_values: HashMap<String, ChoiceValue>,
    circuit_breaker_open: bool,
}

impl<H: Handler> RobustProductionHandler<H> {
    fn new(inner: H, max_errors: u32) -> Self {
        let mut fallback_values = HashMap::new();
        fallback_values.insert("default_f64".to_string(), ChoiceValue::F64(0.0));
        fallback_values.insert("default_bool".to_string(), ChoiceValue::Bool(false));
        fallback_values.insert("default_u64".to_string(), ChoiceValue::U64(0));
        fallback_values.insert("default_usize".to_string(), ChoiceValue::Usize(0));

        Self {
            inner,
            error_count: 0,
            max_errors,
            _fallback_values: fallback_values,
            circuit_breaker_open: false,
        }
    }

    fn handle_error(&mut self, operation: &str, addr: &Address) -> bool {
        self.error_count += 1;
        eprintln!("PRODUCTION ERROR: {} failed at address {}", operation, addr);

        if self.error_count >= self.max_errors {
            self.circuit_breaker_open = true;
            eprintln!("CIRCUIT BREAKER: Too many errors, switching to fallback mode");
        }

        self.circuit_breaker_open
    }

    fn get_fallback_f64(&self, addr: &Address) -> f64 {
        // In production, this might come from a cache, configuration, or ML model
        match addr.0.as_str() {
            s if s.contains("temperature") => 20.0,
            s if s.contains("price") => 100.0,
            s if s.contains("probability") => 0.5,
            _ => 0.0,
        }
    }
}

impl<H: Handler> Handler for RobustProductionHandler<H> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        if self.circuit_breaker_open {
            return self.get_fallback_f64(addr);
        }

        match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            self.inner.on_sample_f64(addr, dist)
        })) {
            Ok(value) if value.is_finite() => value,
            Ok(invalid_value) => {
                eprintln!("Invalid f64 sample: {} at {}", invalid_value, addr);
                self.handle_error("sample_f64", addr);
                self.get_fallback_f64(addr)
            }
            Err(_) => {
                self.handle_error("sample_f64_panic", addr);
                self.get_fallback_f64(addr)
            }
        }
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        if self.circuit_breaker_open {
            return false; // Safe fallback
        }

        match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            self.inner.on_sample_bool(addr, dist)
        })) {
            Ok(value) => value,
            Err(_) => {
                self.handle_error("sample_bool_panic", addr);
                false
            }
        }
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        if self.circuit_breaker_open {
            return 1; // Safe default
        }

        match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            self.inner.on_sample_u64(addr, dist)
        })) {
            Ok(value) => value,
            Err(_) => {
                self.handle_error("sample_u64_panic", addr);
                1
            }
        }
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        if self.circuit_breaker_open {
            return 0; // Safe array index
        }

        match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            self.inner.on_sample_usize(addr, dist)
        })) {
            Ok(value) => value,
            Err(_) => {
                self.handle_error("sample_usize_panic", addr);
                0
            }
        }
    }

    fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        if !self.circuit_breaker_open
            && std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                self.inner.on_observe_f64(addr, dist, value)
            }))
            .is_err()
        {
            self.handle_error("observe_f64_panic", addr);
        }
    }

    fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        if !self.circuit_breaker_open
            && std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                self.inner.on_observe_bool(addr, dist, value)
            }))
            .is_err()
        {
            self.handle_error("observe_bool_panic", addr);
        }
    }

    fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        if !self.circuit_breaker_open
            && std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                self.inner.on_observe_u64(addr, dist, value)
            }))
            .is_err()
        {
            self.handle_error("observe_u64_panic", addr);
        }
    }

    fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        if !self.circuit_breaker_open
            && std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                self.inner.on_observe_usize(addr, dist, value)
            }))
            .is_err()
        {
            self.handle_error("observe_usize_panic", addr);
        }
    }

    fn on_factor(&mut self, log_weight: f64) {
        if !log_weight.is_finite() {
            eprintln!("Invalid factor log-weight: {}", log_weight);
            self.error_count += 1;
            return; // Skip invalid factors
        }

        if !self.circuit_breaker_open {
            self.inner.on_factor(log_weight);
        }
    }

    fn finish(self) -> Trace {
        println!("โœ… Production handler statistics:");
        println!(
            "   - Errors encountered: {}/{}",
            self.error_count, self.max_errors
        );
        println!(
            "   - Circuit breaker status: {}",
            if self.circuit_breaker_open {
                "OPEN (fallback mode)"
            } else {
                "CLOSED (normal)"
            }
        );

        if self.circuit_breaker_open {
            // Return minimal valid trace in fallback mode
            Trace::default()
        } else {
            self.inner.finish()
        }
    }
}

Robust Error Handling Features:

  • Circuit Breaker Pattern: Prevents cascade failures by switching to fallback mode
  • Panic Recovery: Catches panics and provides safe default values
  • Input Validation: Ensures all inputs are finite and within expected ranges
  • Fallback Values: Domain-specific defaults for different parameter types
  • Error Counting: Tracks error rates to trigger circuit breaker activation

Configuration Management

Production models require flexible configuration for different environments:

#[derive(Debug, Clone)]
struct ModelConfig {
    // Model parameters
    temperature_prior_mean: f64,
    temperature_prior_std: f64,
    validity_probability: f64,
    sensor_noise_std: f64,

    // Runtime configuration
    max_inference_time_ms: u64,
    memory_pool_size: usize,
    enable_circuit_breaker: bool,
    error_threshold: u32,

    // Environment settings
    environment: String, // "development", "staging", "production"
    _log_level: String,
    enable_metrics: bool,
}

impl Default for ModelConfig {
    fn default() -> Self {
        Self {
            temperature_prior_mean: 20.0,
            temperature_prior_std: 5.0,
            validity_probability: 0.95,
            sensor_noise_std: 1.0,
            max_inference_time_ms: 1000,
            memory_pool_size: 100,
            enable_circuit_breaker: true,
            error_threshold: 10,
            environment: "production".to_string(),
            _log_level: "info".to_string(),
            enable_metrics: true,
        }
    }
}

struct ConfigurableModelRunner {
    config: ModelConfig,
    pool: TracePool,
    metrics: ProductionMetrics,
}

impl ConfigurableModelRunner {
    fn new(config: ModelConfig) -> Self {
        Self {
            pool: TracePool::new(config.memory_pool_size),
            metrics: ProductionMetrics::new(config.enable_metrics),
            config,
        }
    }

    fn create_model(&self) -> Model<(f64, bool)> {
        let config = self.config.clone();
        prob!(
            let temp <- sample(
                addr!("temperature"),
                Normal::new(config.temperature_prior_mean, config.temperature_prior_std).unwrap()
            );
            let valid <- sample(
                addr!("valid"),
                Bernoulli::new(config.validity_probability).unwrap()
            );
            // Simulate sensor reading with configured noise
            observe(
                addr!("sensor"),
                Normal::new(temp, config.sensor_noise_std).unwrap(),
                22.0
            );
            pure((temp, valid))
        )
    }

    fn run_inference(&mut self) -> Result<(f64, bool), String> {
        let start = Instant::now();

        // Configure handler based on environment
        let mut rng = thread_rng();
        let model = self.create_model(); // Create model before borrowing
        let result = if self.config.environment == "production" {
            // Use safe, fault-tolerant execution in production
            let base_handler = PooledPriorHandler::new(&mut rng, &mut self.pool);
            let robust_handler =
                RobustProductionHandler::new(base_handler, self.config.error_threshold);

            let (result, _trace) = runtime::handler::run(robust_handler, model);
            Ok(result)
        } else {
            // Use faster, less safe execution in development
            let handler = PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            };
            let (result, _trace) = runtime::handler::run(handler, model);
            Ok(result)
        };

        let duration = start.elapsed();
        self.metrics.record_inference_time(duration);

        // Check timeout
        if duration.as_millis() > self.config.max_inference_time_ms as u128 {
            self.metrics.increment_timeout_count();
            return Err(format!(
                "Inference timeout: {}ms > {}ms",
                duration.as_millis(),
                self.config.max_inference_time_ms
            ));
        }

        result
    }
}

Configuration Best Practices:

  • Environment-Specific Settings: Different behavior for development/staging/production (see the sketch after this list)
  • Model Parameter Configuration: Tunable priors, noise levels, and thresholds
  • Runtime Configuration: Memory pool sizes, timeout limits, error thresholds
  • Deployment Configuration: Circuit breaker settings, logging levels, metrics enablement
  • Type-Safe Defaults: Sensible fallbacks for all configuration parameters
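
A small sketch of environment-driven overrides at startup; the variable names (FUGUE_ENV, FUGUE_ERROR_THRESHOLD) are illustrative, not a Fugue convention.

fn load_config() -> ModelConfig {
    let mut config = ModelConfig::default();
    if let Ok(env) = std::env::var("FUGUE_ENV") {
        config.environment = env; // "development", "staging", or "production"
    }
    if let Ok(raw) = std::env::var("FUGUE_ERROR_THRESHOLD") {
        if let Ok(threshold) = raw.parse::<u32>() {
            config.error_threshold = threshold;
        }
    }
    config
}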

Production Metrics and Observability

Observability requires systematic metric collection with statistical analysis and anomaly detection. The metric taxonomy follows the USE method (Utilization, Saturation, Errors) and RED method (Rate, Errors, Duration):

graph TD
    subgraph "Observability Architecture"
        A[Model Execution] --> B[Metric Collection]
        B --> C{Metric Type}
        C -->|USE| D[Resource Metrics]
        C -->|RED| E[Service Metrics]  
        C -->|Business| F[Domain Metrics]
        
        D --> G[Utilization: ฯ = ฮป/ฮผ]
        D --> H[Saturation: Queue Length]
        D --> I[Error Rate: ฮปโ‚‘]
        
        E --> J[Request Rate: ฮปแตฃ]
        E --> K[Error Rate: ฮตแตฃ] 
        E --> L[Duration: Tโ‚‰โ‚‰]
        
        F --> M[Inference Accuracy]
        F --> N[Model Drift]
        F --> O[Business KPIs]
        
        G --> P[(Time Series DB)]
        H --> P
        I --> P
        J --> P
        K --> P
        L --> P
        M --> Q[(Analytics DB)]
        N --> Q
        O --> Q
    end

Statistical Process Control: Metrics follow control chart theory, with control limits placed at μ ± 3σ around each metric's historical mean.

Anomaly Detection: Using an exponentially weighted moving average, z_t = α·x_t + (1 - α)·z_{t-1}, a sample is flagged when it drifts outside the control band, as sketched below.
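
A minimal EWMA sketch for any scalar metric stream (latency, error rate); the fixed tolerance band stands in for full control-chart limits. The ProductionMetrics collector that follows accumulates the raw counters such checks would consume.

/// Smooths a metric stream and flags samples that drift outside a band
/// around the smoothed estimate.
struct EwmaMonitor {
    alpha: f64,         // smoothing factor in (0, 1]
    estimate: Option<f64>,
    tolerance: f64,     // allowed absolute deviation before flagging
}

impl EwmaMonitor {
    fn observe(&mut self, x: f64) -> bool {
        let z = match self.estimate {
            Some(prev) => self.alpha * x + (1.0 - self.alpha) * prev,
            None => x,
        };
        self.estimate = Some(z);
        (x - z).abs() > self.tolerance // true => anomalous sample
    }
}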

#[derive(Debug, Clone)]
struct ProductionMetrics {
    enabled: bool,
    inference_count: u64,
    error_count: u64,
    timeout_count: u64,
    total_inference_time: Duration,
    start_time: SystemTime,
}

impl ProductionMetrics {
    fn new(enabled: bool) -> Self {
        Self {
            enabled,
            inference_count: 0,
            error_count: 0,
            timeout_count: 0,
            total_inference_time: Duration::ZERO,
            start_time: SystemTime::now(),
        }
    }

    fn record_inference_time(&mut self, duration: Duration) {
        if self.enabled {
            self.inference_count += 1;
            self.total_inference_time += duration;
        }
    }

    fn _increment_error_count(&mut self) {
        if self.enabled {
            self.error_count += 1;
        }
    }

    fn increment_timeout_count(&mut self) {
        if self.enabled {
            self.timeout_count += 1;
        }
    }

    fn get_stats(&self) -> HashMap<String, f64> {
        let mut stats = HashMap::new();
        if self.enabled {
            let uptime = self.start_time.elapsed().unwrap_or(Duration::ZERO);
            let avg_inference_time = if self.inference_count > 0 {
                self.total_inference_time.as_millis() as f64 / self.inference_count as f64
            } else {
                0.0
            };

            stats.insert("inference_count".to_string(), self.inference_count as f64);
            stats.insert("error_count".to_string(), self.error_count as f64);
            stats.insert("timeout_count".to_string(), self.timeout_count as f64);
            stats.insert(
                "error_rate".to_string(),
                if self.inference_count > 0 {
                    self.error_count as f64 / self.inference_count as f64
                } else {
                    0.0
                },
            );
            stats.insert("avg_inference_time_ms".to_string(), avg_inference_time);
            stats.insert("uptime_seconds".to_string(), uptime.as_secs() as f64);
            stats.insert(
                "throughput_per_second".to_string(),
                if uptime.as_secs() > 0 {
                    self.inference_count as f64 / uptime.as_secs() as f64
                } else {
                    0.0
                },
            );
        }
        stats
    }

    fn export_prometheus_metrics(&self) -> String {
        let mut metrics = String::new();
        let stats = self.get_stats();

        metrics.push_str("# HELP fugue_inference_total Total number of inference runs\n");
        metrics.push_str("# TYPE fugue_inference_total counter\n");
        metrics.push_str(&format!(
            "fugue_inference_total {}\n",
            stats.get("inference_count").unwrap_or(&0.0)
        ));

        metrics.push_str("# HELP fugue_errors_total Total number of errors\n");
        metrics.push_str("# TYPE fugue_errors_total counter\n");
        metrics.push_str(&format!(
            "fugue_errors_total {}\n",
            stats.get("error_count").unwrap_or(&0.0)
        ));

        metrics.push_str(
            "# HELP fugue_inference_duration_ms Average inference duration in milliseconds\n",
        );
        metrics.push_str("# TYPE fugue_inference_duration_ms gauge\n");
        metrics.push_str(&format!(
            "fugue_inference_duration_ms {}\n",
            stats.get("avg_inference_time_ms").unwrap_or(&0.0)
        ));

        metrics.push_str("# HELP fugue_error_rate Error rate (errors/total inferences)\n");
        metrics.push_str("# TYPE fugue_error_rate gauge\n");
        metrics.push_str(&format!(
            "fugue_error_rate {}\n",
            stats.get("error_rate").unwrap_or(&0.0)
        ));

        metrics
    }
}

/// Production monitoring handler that integrates with metrics systems
struct MetricsHandler<H: Handler> {
    inner: H,
    metrics: Arc<std::sync::Mutex<ProductionMetrics>>,
    _model_name: String,
}

impl<H: Handler> MetricsHandler<H> {
    fn new(
        inner: H,
        metrics: Arc<std::sync::Mutex<ProductionMetrics>>,
        model_name: String,
    ) -> Self {
        Self {
            inner,
            metrics,
            _model_name: model_name,
        }
    }
}

impl<H: Handler> Handler for MetricsHandler<H> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        let start = Instant::now();
        let result = self.inner.on_sample_f64(addr, dist);

        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.record_inference_time(start.elapsed());
        }

        result
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        self.inner.on_sample_bool(addr, dist)
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        self.inner.on_sample_u64(addr, dist)
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        self.inner.on_sample_usize(addr, dist)
    }

    fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        self.inner.on_observe_f64(addr, dist, value);
    }

    fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        self.inner.on_observe_bool(addr, dist, value);
    }

    fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        self.inner.on_observe_u64(addr, dist, value);
    }

    fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        self.inner.on_observe_usize(addr, dist, value);
    }

    fn on_factor(&mut self, log_weight: f64) {
        self.inner.on_factor(log_weight);
    }

    fn finish(self) -> Trace {
        self.inner.finish()
    }
}

Metrics Collection:

  • Performance Metrics: Inference time, throughput, operation counts
  • Error Tracking: Error rates, timeout counts, failure categorization
  • System Health: Uptime, resource utilization, memory pool efficiency
  • Prometheus Integration: Standard metrics format for monitoring systems
  • Real-Time Dashboards: Live performance and health indicators
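As a usage sketch, the decorator above wraps any base handler; here it wraps PriorHandler and exports Prometheus text afterwards (my_model is a placeholder for one of your Model-returning functions):

// Shared metrics store, cloned into each instrumented handler.
let metrics = Arc::new(std::sync::Mutex::new(ProductionMetrics::new(true)));

let mut rng = thread_rng();
let base = PriorHandler { rng: &mut rng, trace: Trace::default() };
let handler = MetricsHandler::new(base, Arc::clone(&metrics), "my_model".to_string());

// Run the model through the instrumented handler.
let (_value, _trace) = runtime::handler::run(handler, my_model());

// Expose the collected metrics, e.g. from an HTTP /metrics endpoint.
println!("{}", metrics.lock().unwrap().export_prometheus_metrics());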

Health Checks and System Validation

Health monitoring implements continuous system validation through multi-level health checks with statistical thresholds and predictive alerting:

graph TD
    subgraph "Health Check Hierarchy"
        A[System Health H⁽ˢ⁾] --> B[Model Health H⁽ᵐ⁾]
        B --> C[Inference Health H⁽ⁱ⁾]
        C --> D[Resource Health H⁽ʳ⁾]

        A --> E{H⁽ˢ⁾ > θₛ?}
        B --> F{H⁽ᵐ⁾ > θₘ?}
        C --> G{H⁽ⁱ⁾ > θᵢ?}
        D --> H{H⁽ʳ⁾ > θᵣ?}

        E -->|No| I[System Alert]
        F -->|No| J[Model Alert]
        G -->|No| K[Inference Alert]
        H -->|No| L[Resource Alert]

        E -->|Yes| M[System OK]
        F -->|Yes| M
        G -->|Yes| M
        H -->|Yes| M
    end

Health Score Calculation: the overall score is a weighted combination of the subsystem scores,

H = w_s·H⁽ˢ⁾ + w_m·H⁽ᵐ⁾ + w_i·H⁽ⁱ⁾ + w_r·H⁽ʳ⁾

where the w_k are importance weights with Σ w_k = 1.

Predictive Health Modeling: forecasting the health score over time (for example with an EWMA or a short-horizon trend extrapolation) turns reactive alerts into predictive ones.

Health Degradation Alert

Early Warning System: when the forecast health score stays below its threshold for a sustained period, the system triggers preemptive scaling or graceful degradation before reaching critical thresholds.
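A small sketch of that weighted combination (weights and example scores are illustrative):

/// Combine subsystem health scores in [0, 1] into one overall score.
fn overall_health(weighted_scores: &[(f64, f64)]) -> f64 {
    // Each entry is (weight, subsystem score); weights are renormalized to sum to 1.
    let total_weight: f64 = weighted_scores.iter().map(|(w, _)| *w).sum();
    weighted_scores.iter().map(|(w, h)| w * h).sum::<f64>() / total_weight
}

// e.g. system 0.4, model 0.3, inference 0.2, resource 0.1
let h = overall_health(&[(0.4, 0.95), (0.3, 0.90), (0.2, 0.99), (0.1, 0.80)]);
assert!(h > 0.0 && h <= 1.0);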

#[derive(Debug, Clone)]
struct HealthCheckResult {
    status: HealthStatus,
    message: String,
    details: HashMap<String, String>,
    timestamp: SystemTime,
}

#[derive(Debug, Clone, PartialEq)]
enum HealthStatus {
    Healthy,
    Degraded,
    Unhealthy,
}

struct ProductionHealthChecker {
    model_config: ModelConfig,
    metrics: Arc<std::sync::Mutex<ProductionMetrics>>,
}

impl ProductionHealthChecker {
    fn new(config: ModelConfig, metrics: Arc<std::sync::Mutex<ProductionMetrics>>) -> Self {
        Self {
            model_config: config,
            metrics,
        }
    }

    fn run_health_check(&self) -> HealthCheckResult {
        let mut details = HashMap::new();
        let mut overall_status = HealthStatus::Healthy;
        let mut messages = Vec::new();

        // Check 1: Model execution health
        match self.check_model_execution() {
            Ok(duration) => {
                details.insert("model_execution".to_string(), "healthy".to_string());
                details.insert(
                    "execution_time_ms".to_string(),
                    format!("{:.1}", duration.as_millis()),
                );
            }
            Err(e) => {
                overall_status = HealthStatus::Unhealthy;
                messages.push(format!("Model execution failed: {}", e));
                details.insert("model_execution".to_string(), "failed".to_string());
            }
        }

        // Check 2: Memory usage
        if let Some(pool_stats) = self.check_memory_health() {
            let hit_ratio = pool_stats.hit_ratio();
            details.insert("memory_hit_ratio".to_string(), format!("{:.2}%", hit_ratio));

            if hit_ratio < 50.0 {
                overall_status = HealthStatus::Degraded;
                messages.push("Low memory pool hit ratio".to_string());
            }
        }

        // Check 3: Error rates
        if let Ok(metrics) = self.metrics.lock() {
            let stats = metrics.get_stats();
            let error_rate = stats.get("error_rate").unwrap_or(&0.0) * 100.0;
            details.insert(
                "error_rate_percent".to_string(),
                format!("{:.2}%", error_rate),
            );

            if error_rate > 20.0 {
                overall_status = HealthStatus::Unhealthy;
                messages.push(format!("Critical error rate: {:.1}%", error_rate));
            } else if error_rate > 5.0 {
                overall_status = HealthStatus::Degraded;
                messages.push(format!("High error rate: {:.1}%", error_rate));
            }

            let avg_time = stats.get("avg_inference_time_ms").unwrap_or(&0.0);
            details.insert(
                "avg_inference_time_ms".to_string(),
                format!("{:.1}", avg_time),
            );

            if *avg_time > self.model_config.max_inference_time_ms as f64 * 0.8 {
                overall_status = HealthStatus::Degraded;
                messages.push("Inference time approaching timeout threshold".to_string());
            }
        }

        // Check 4: System resources
        details.insert(
            "memory_pool_size".to_string(),
            self.model_config.memory_pool_size.to_string(),
        );
        details.insert(
            "circuit_breaker".to_string(),
            if self.model_config.enable_circuit_breaker {
                "enabled".to_string()
            } else {
                "disabled".to_string()
            },
        );

        let message = if messages.is_empty() {
            "All systems healthy".to_string()
        } else {
            messages.join("; ")
        };

        HealthCheckResult {
            status: overall_status,
            message,
            details,
            timestamp: SystemTime::now(),
        }
    }

    fn check_model_execution(&self) -> Result<Duration, String> {
        let start = Instant::now();
        let mut rng = thread_rng();

        // Run a simplified version of the model for health checking
        let health_model = || {
            prob!(
                let value <- sample(addr!("health_check"), Normal::new(0.0, 1.0).unwrap());
                pure(value)
            )
        };

        let handler = PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        };

        match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            runtime::handler::run(handler, health_model())
        })) {
            Ok((result, trace)) => {
                if result.is_finite() && trace.total_log_weight().is_finite() {
                    Ok(start.elapsed())
                } else {
                    Err("Invalid model output".to_string())
                }
            }
            Err(_) => Err("Model execution panicked".to_string()),
        }
    }

    fn check_memory_health(&self) -> Option<fugue::runtime::memory::PoolStats> {
        // In a real implementation, this would check the actual memory pool
        // For demonstration, we'll create a temporary pool
        let pool = TracePool::new(10);
        Some(pool.stats().clone())
    }
}

Health Check Components:

  • Model Execution Health: Verifies core functionality with simplified tests
  • Memory Health: Monitors pool efficiency and memory usage patterns
  • Error Rate Analysis: Tracks and categorizes different failure modes
  • Performance Monitoring: Identifies degradation before it impacts users
  • Multi-Level Status: Healthy/Degraded/Unhealthy with detailed diagnostics
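As a usage sketch, a readiness endpoint can map the check result onto HTTP-style outcomes (ModelConfig::default() is assumed here purely for brevity; pass your real configuration):

let metrics = Arc::new(std::sync::Mutex::new(ProductionMetrics::new(true)));
let checker = ProductionHealthChecker::new(ModelConfig::default(), Arc::clone(&metrics));

let result = checker.run_health_check();
match result.status {
    HealthStatus::Healthy => println!("200 OK: {}", result.message),
    HealthStatus::Degraded => println!("200 OK (degraded): {}", result.message),
    HealthStatus::Unhealthy => eprintln!("503 Service Unavailable: {}", result.message),
}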

Input Validation and Security

Robust input validation prevents security vulnerabilities and system failures:

/// Secure input validator for production model parameters
struct InputValidator;

impl InputValidator {
    fn validate_temperature(temp: f64) -> Result<f64, String> {
        match temp {
            t if !t.is_finite() => Err("Temperature must be finite".to_string()),
            t if t < -50.0 => Err("Temperature too low (< -50°C)".to_string()),
            t if t > 100.0 => Err("Temperature too high (> 100°C)".to_string()),
            t => Ok(t),
        }
    }

    fn validate_probability(p: f64) -> Result<f64, String> {
        match p {
            p if !p.is_finite() => Err("Probability must be finite".to_string()),
            p if p < 0.0 => Err("Probability must be non-negative".to_string()),
            p if p > 1.0 => Err("Probability must not exceed 1.0".to_string()),
            p => Ok(p),
        }
    }

    fn _validate_sensor_reading(reading: f64) -> Result<f64, String> {
        match reading {
            r if !r.is_finite() => Err("Sensor reading must be finite".to_string()),
            r if r.abs() > 1000.0 => Err("Sensor reading out of reasonable range".to_string()),
            r => Ok(r),
        }
    }

    fn sanitize_address_component(component: &str) -> Result<String, String> {
        // Prevent injection attacks in address components
        if component
            .chars()
            .any(|c| !(c.is_alphanumeric() || c == '_' || c == '-'))
        {
            return Err("Address component contains invalid characters".to_string());
        }

        if component.len() > 50 {
            Err("Address component too long".to_string())
        } else if component.is_empty() {
            Err("Address component cannot be empty".to_string())
        } else {
            Ok(component.to_string())
        }
    }
}

/// Production model with comprehensive input validation
fn create_validated_model(
    temperature_reading: f64,
    sensor_id: &str,
    prior_prob: f64,
) -> Result<Model<(f64, bool)>, String> {
    // Validate all inputs before model creation
    let validated_temp = InputValidator::validate_temperature(temperature_reading)?;
    let validated_prob = InputValidator::validate_probability(prior_prob)?;
    let sanitized_sensor_id = InputValidator::sanitize_address_component(sensor_id)?;

    // Additional business logic validation
    if sanitized_sensor_id.starts_with("test_") && validated_prob > 0.5 {
        return Err("Test sensors cannot have high prior probability".to_string());
    }

    Ok(prob!(
        let true_temp <- sample(
            addr!("temperature"),
            Normal::new(validated_temp, 2.0).unwrap()
        );
        let is_working <- sample(
            addr!("sensor_working", sanitized_sensor_id.clone()),
            Bernoulli::new(validated_prob).unwrap()
        );

        // Safe observation with validated input
        observe(
            addr!("reading", sanitized_sensor_id),
            Normal::new(true_temp, if is_working { 0.5 } else { 5.0 }).unwrap(),
            validated_temp
        );

        pure((true_temp, is_working))
    ))
}

Security Measures:

  • Range Validation: Ensure parameters are within physically meaningful bounds
  • Type Safety: Validate all inputs before model construction
  • Sanitization: Clean address components to prevent injection attacks
  • Business Rule Enforcement: Domain-specific validation logic
  • Error Messages: Informative feedback without revealing system internals
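As a usage sketch, validation failures are surfaced to the caller before any probabilistic work happens; only validated inputs ever reach the model (the sensor id and readings below are made-up values):

match create_validated_model(22.5, "sensor_42", 0.9) {
    Ok(model) => {
        let mut rng = thread_rng();
        let (value, trace) = runtime::handler::run(
            PriorHandler { rng: &mut rng, trace: Trace::default() },
            model,
        );
        println!(
            "temp = {:.2}, sensor working = {}, log-weight = {:.3}",
            value.0, value.1, trace.total_log_weight()
        );
    }
    Err(reason) => eprintln!("Request rejected: {}", reason),
}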

Deployment Strategies and Patterns

Deployment strategies implement risk management through controlled rollout and statistical validation. Each strategy provides different risk-latency tradeoffs:

graph TD
    subgraph "Deployment Strategy Matrix"  
        A[New Model Version] --> B{Strategy Selection}
        B -->|Low Risk| C[Blue-Green]
        B -->|Medium Risk| D[Canary]
        B -->|High Risk| E[A/B Test]

        C --> F[Instant Switch<br/>Risk: High<br/>Rollback: Fast]
        D --> G[Gradual Rollout<br/>Risk: Medium<br/>Validation: Statistical]
        E --> H[Statistical Test<br/>Risk: Low<br/>Duration: Long]

        F --> I{Success?}
        G --> J{Performance > Baseline?}
        H --> K{Significance Test?}

        I -->|No| L[Instant Rollback]
        J -->|No| M[Gradual Rollback]
        K -->|No| N[Maintain Status Quo]

        I -->|Yes| O[Full Deployment]
        J -->|Yes| P[Continue Rollout]
        K -->|Yes| Q[Gradual Migration]
    end

Canary Analysis: a canary deployment compares the canary's error rate and latency against the baseline with a statistical significance test before widening the rollout.

A/B Testing: Welch's t-test compares the two variants' means when their variances are unequal:

t = (m_A − m_B) / √(s_A²/n_A + s_B²/n_B)

where m, s², and n are each variant's sample mean, sample variance, and sample size.
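A small helper sketch for that statistic (summary statistics in, t value out; comparison against the Welch-Satterthwaite degrees of freedom is left to a stats crate):

/// Welch's t statistic for two samples with unequal variances.
fn welch_t(m_a: f64, var_a: f64, n_a: usize, m_b: f64, var_b: f64, n_b: usize) -> f64 {
    let se = (var_a / n_a as f64 + var_b / n_b as f64).sqrt();
    (m_a - m_b) / se
}

// e.g. candidate latency 41.2 ms (var 4.0, n = 500) vs baseline 40.8 ms (var 3.5, n = 5000)
let t = welch_t(41.2, 4.0, 500, 40.8, 3.5, 5000);
println!("Welch t = {:.2}", t);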

/// Production deployment manager with different strategies
#[derive(Debug, Clone)]
enum DeploymentStrategy {
    BlueGreen,
    CanaryRelease { percentage: f64 },
    RollingUpdate,
    _ImmediateSwitch,
}

struct ModelDeploymentManager {
    current_model_version: String,
    candidate_model_version: String,
    deployment_strategy: DeploymentStrategy,
    _rollback_threshold_error_rate: f64,
}

impl ModelDeploymentManager {
    fn new(strategy: DeploymentStrategy) -> Self {
        Self {
            current_model_version: "v1.0.0".to_string(),
            candidate_model_version: "v1.1.0".to_string(),
            deployment_strategy: strategy,
            _rollback_threshold_error_rate: 0.05, // 5% error rate triggers rollback
        }
    }

    fn should_use_candidate_model(&self, request_id: u64) -> bool {
        match &self.deployment_strategy {
            DeploymentStrategy::BlueGreen => {
                // In blue-green, we typically switch all traffic at once
                // For demo, we'll use request ID to simulate the switch
                request_id % 100 < 10 // 10% to candidate for testing
            }
            DeploymentStrategy::CanaryRelease { percentage } => {
                let hash = request_id % 100;
                (hash as f64) < (*percentage * 100.0)
            }
            DeploymentStrategy::RollingUpdate => {
                // Gradual rollout based on some criteria
                request_id % 10 < 3 // 30% rollout
            }
            DeploymentStrategy::_ImmediateSwitch => true,
        }
    }

    fn create_model(&self, use_candidate: bool) -> impl Fn() -> Model<f64> {
        let version = if use_candidate {
            self.candidate_model_version.clone()
        } else {
            self.current_model_version.clone()
        };

        move || {
            if version.starts_with("v1.1") {
                // Candidate model with improved parameters
                prob!(
                    let value <- sample(addr!("improved_param"), Normal::new(0.0, 0.8).unwrap());
                    factor(0.1); // Slight preference for this model
                    pure(value)
                )
            } else {
                // Current stable model
                prob!(
                    let value <- sample(addr!("stable_param"), Normal::new(0.0, 1.0).unwrap());
                    pure(value)
                )
            }
        }
    }

    fn process_request(&self, request_id: u64) -> Result<(f64, String), String> {
        let use_candidate = self.should_use_candidate_model(request_id);
        let version = if use_candidate {
            &self.candidate_model_version
        } else {
            &self.current_model_version
        };

        let model = self.create_model(use_candidate);

        // Execute with error handling
        let mut rng = thread_rng();
        let handler = PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        };

        match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            runtime::handler::run(handler, model())
        })) {
            Ok((result, trace)) => {
                if result.is_finite() && trace.total_log_weight().is_finite() {
                    Ok((result, version.clone()))
                } else {
                    Err(format!("Invalid result from model {}", version))
                }
            }
            Err(_) => Err(format!("Model {} panicked", version)),
        }
    }
}

Deployment Patterns:

  • Blue-Green Deployment: Instant traffic switching between model versions
  • Canary Releases: Gradual rollout to percentage of traffic for risk mitigation
  • Rolling Updates: Progressive deployment across infrastructure
  • A/B Testing: Compare model performance with statistical significance
  • Rollback Capability: Quick reversion to previous version on issues
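As a usage sketch, routing a stream of requests through a 10% canary looks like this (request IDs and the error-handling policy are illustrative):

let manager = ModelDeploymentManager::new(DeploymentStrategy::CanaryRelease { percentage: 0.10 });

for request_id in 0..20u64 {
    match manager.process_request(request_id) {
        Ok((value, version)) => {
            println!("request {} served by {}: {:.3}", request_id, version, value);
        }
        // In a real system, per-version error counters would feed the rollback decision.
        Err(reason) => eprintln!("request {} failed: {}", request_id, reason),
    }
}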

Performance Optimization Patterns

Memory Management

use fugue::runtime::memory::{TracePool, PooledPriorHandler};

// Production memory management
let mut pool = TracePool::new(1000);
let handler = PooledPriorHandler::new(&mut rng, &mut pool);

Batch Processing

// Process multiple inference requests efficiently
struct BatchProcessor {
    pool: TracePool,
    batch_size: usize,
}

impl BatchProcessor {
    fn process_batch(&mut self, requests: Vec<InferenceRequest>) -> Vec<InferenceResult> {
        let mut results = Vec::with_capacity(requests.len());
        for mut req in requests {
            // Reuse one shared pool across the whole batch instead of allocating per request.
            let handler = PooledPriorHandler::new(&mut req.rng, &mut self.pool);
            // run_single_inference is assumed to be a free helper so the pool borrow stays valid.
            results.push(run_single_inference(handler, req.model));
        }
        results
    }
}

Connection Pooling

// Database connection management for model parameters
struct ModelParameterStore {
    connection_pool: Arc<ConnectionPool>,
    parameter_cache: LruCache<String, ModelParameters>,
}

Monitoring Integration

Prometheus Metrics

// Export metrics in Prometheus format
fn export_metrics(metrics: &ProductionMetrics) -> String {
    format!(
        "# HELP fugue_inference_total Total inference operations\n\
         # TYPE fugue_inference_total counter\n\
         fugue_inference_total {}\n\
         # HELP fugue_error_rate Current error rate\n\
         # TYPE fugue_error_rate gauge\n\
         fugue_error_rate {}\n",
        metrics.inference_count,
        metrics.error_rate()
    )
}

Structured Logging

use serde_json::json;

// Structured logging for production debugging
fn log_inference_event(
    request_id: &str,
    model_version: &str,
    duration: Duration,
    result: &InferenceResult
) {
    let log_entry = json!({
        "event": "inference_completed",
        "request_id": request_id,
        "model_version": model_version,
        "duration_ms": duration.as_millis(),
        "success": result.is_success(),
        "timestamp": SystemTime::now(),
    });
    println!("{}", log_entry);
}

Alert Rules

// Define alerting thresholds
struct AlertRules {
    max_error_rate: f64,
    max_latency_ms: u64,
    min_throughput_per_sec: f64,
}

impl AlertRules {
    fn check_alerts(&self, metrics: &ProductionMetrics) -> Vec<Alert> {
        let mut alerts = Vec::new();

        if metrics.error_rate() > self.max_error_rate {
            alerts.push(Alert::HighErrorRate(metrics.error_rate()));
        }

        if metrics.avg_latency().as_millis() > self.max_latency_ms as u128 {
            alerts.push(Alert::HighLatency(metrics.avg_latency()));
        }

        alerts
    }
}

Testing in Production

Shadow Mode Testing

// Run new model versions in shadow mode
struct ShadowTester {
    primary_model: Box<dyn Fn() -> Model<f64>>,
    shadow_model: Box<dyn Fn() -> Model<f64>>,
    comparison_rate: f64,
}

impl ShadowTester {
    fn run_with_shadow(&mut self, input: &Input) -> (PrimaryResult, Option<ShadowResult>) {
        let primary = self.run_primary(input);

        let shadow = if rand::random::<f64>() < self.comparison_rate {
            Some(self.run_shadow(input))
        } else {
            None
        };

        (primary, shadow)
    }
}

Production Validation

// Continuous validation in production
fn validate_model_assumptions(trace: &Trace) -> ValidationResult {
    let mut issues = Vec::new();

    // Check log-weight stability
    if !trace.total_log_weight().is_finite() {
        issues.push("Non-finite log-weight detected".to_string());
    }

    // Check parameter ranges
    for (addr, choice) in &trace.choices {
        if let ChoiceValue::F64(value) = choice.value {
            if value.abs() > 1000.0 {
                issues.push(format!("Extreme value at {}: {}", addr, value));
            }
        }
    }

    ValidationResult { issues }
}

Operational Excellence

Infrastructure as Code

# Kubernetes deployment example
apiVersion: apps/v1
kind: Deployment
metadata:
  name: fugue-inference-service
spec:
  replicas: 3
  selector:
    matchLabels:
      app: fugue-inference
  template:
    metadata:
      labels:
        app: fugue-inference
    spec:
      containers:
      - name: inference-service
        image: fugue-inference:v1.2.0
        resources:
          requests:
            memory: "256Mi"
            cpu: "250m"
          limits:
            memory: "512Mi"
            cpu: "500m"
        livenessProbe:
          httpGet:
            path: /health
            port: 8080
          initialDelaySeconds: 30
          periodSeconds: 10

Service Level Objectives (SLOs)

// Define and monitor SLOs
struct ServiceLevelObjectives {
    availability_target: f64,    // 99.9%
    latency_p99_ms: u64,        // 100ms
    error_rate_threshold: f64,   // 0.1%
}

impl ServiceLevelObjectives {
    fn evaluate_slo_compliance(&self, metrics: &ProductionMetrics) -> SLOReport {
        SLOReport {
            availability: self.calculate_availability(metrics),
            latency_compliance: metrics.p99_latency() <= Duration::from_millis(self.latency_p99_ms),
            error_rate_compliance: metrics.error_rate() <= self.error_rate_threshold,
        }
    }
}

Capacity Planning

// Capacity planning and auto-scaling
struct CapacityPlanner {
    target_cpu_utilization: f64,
    target_memory_utilization: f64,
    scale_up_threshold: f64,
    scale_down_threshold: f64,
}

impl CapacityPlanner {
    fn recommend_scaling(&self, current_metrics: &SystemMetrics) -> ScalingRecommendation {
        if current_metrics.cpu_utilization > self.scale_up_threshold {
            ScalingRecommendation::ScaleUp(self.calculate_scale_factor(current_metrics))
        } else if current_metrics.cpu_utilization < self.scale_down_threshold {
            ScalingRecommendation::ScaleDown(0.5)
        } else {
            ScalingRecommendation::NoAction
        }
    }
}

Security Best Practices

Input Sanitization

Always validate and sanitize inputs before processing:

  • Range checks for numerical parameters
  • Character filtering for string inputs
  • Business rule validation for domain constraints
  • Rate limiting to prevent abuse
  • Authentication and authorization for API access

Secret Management

// Secure configuration management
struct SecureConfig {
    database_url: SecretString,
    api_key: SecretString,
    model_parameters: ModelConfig,
}

impl SecureConfig {
    fn from_environment() -> Result<Self, ConfigError> {
        Ok(SecureConfig {
            database_url: env::var("DATABASE_URL")?.into(),
            api_key: env::var("API_KEY")?.into(),
            model_parameters: ModelConfig::from_file("model_config.toml")?,
        })
    }
}

Audit Logging

// Comprehensive audit trail
fn log_inference_request(
    user_id: &str,
    request: &InferenceRequest,
    response: &InferenceResponse,
) {
    let audit_log = AuditLogEntry {
        timestamp: SystemTime::now(),
        user_id: user_id.to_string(),
        action: "inference_request".to_string(),
        input_hash: hash_sensitive_data(&request.input),
        output_hash: hash_sensitive_data(&response.output),
        model_version: response.model_version.clone(),
        success: response.success,
    };

    audit_logger::log(audit_log);
}

Common Production Pitfalls

Memory Leaks

// Avoid: Creating new pools repeatedly
// for _ in 0..1000 {
//     let pool = TracePool::new(100); // Memory leak!
// }

// Do: Reuse pools across requests
let mut pool = TracePool::new(100);
for request in requests {
    let handler = PooledPriorHandler::new(&mut request.rng, &mut pool);
    process_request(handler, request);
}

Blocking Operations

// Avoid: Synchronous database calls in request handlers
// let result = database.query_sync(query); // Blocks event loop

// Do: Use async operations with proper timeouts
async fn process_request(request: Request) -> Result<Response, Error> {
    let timeout = Duration::from_millis(100);
    let result = tokio::time::timeout(timeout, database.query(query)).await??;
    Ok(result)
}

Error Propagation

// Avoid: Panicking on errors
// let value = risky_operation().unwrap(); // May crash service

// Do: Graceful error handling with fallbacks
let value = match risky_operation() {
    Ok(v) => v,
    Err(e) => {
        metrics.increment_error_count();
        log::warn!("Operation failed: {}, using fallback", e);
        fallback_value()
    }
};

Production Excellence Framework

Successful production deployment combines mathematical rigor with engineering excellence:

  1. Reliability Engineering: Fault tolerance through statistical modeling and circuit breaker patterns
  2. Performance Optimization: Memory pooling, numerical stability, and batch processing
  3. Observability: Multi-dimensional metrics with statistical process control
  4. Deployment Strategies: Risk-managed rollouts with statistical validation
  5. Health Monitoring: Predictive alerting and graceful degradation

These patterns enable robust production systems capable of handling real-world probabilistic computing at scale.

Production deployment represents the culmination of probabilistic programming excellence, where theoretical foundations meet operational reality. Fugue's comprehensive tooling transforms academic probabilistic models into production-grade systems that deliver reliable, scalable, and maintainable probabilistic computing solutions.

Tutorials

Foundation Tutorials

Welcome to Fugue's Foundation Tutorials โ€” your comprehensive introduction to the core concepts and unique features that make Fugue a revolutionary approach to probabilistic programming.

What You'll Learn

These tutorials build upon each other to give you a complete understanding of Fugue's foundational principles:

graph TB
    A["๐Ÿช™ Bayesian Coin Flip<br/>Basic Probabilistic Modeling"] --> B["๐Ÿ”’ Type Safety Features<br/>Fugue's Type System Advantages"]
    B --> C["๐Ÿ” Trace Manipulation<br/>Runtime System & Custom Inference"]
    
    A --> D["Statistical Foundations"]
    B --> E["Type-Safe Programming"] 
    C --> F["Advanced Inference"]
    
    D --> G["Production Ready<br/>Probabilistic Models"]
    E --> G
    F --> G
    
    style A fill:#e1f5fe
    style B fill:#f3e5f5  
    style C fill:#e8f5e8
    style G fill:#fff3e0

Learning Path

  1. Bayesian Coin Flip (~45 minutes)

    • Start here for statistical foundations
    • Learn Bayesian inference principles
    • Understand model specification and analysis
  2. Type Safety Features (~30 minutes)

    • Discover Fugue's unique advantages
    • Master type-safe probabilistic programming
    • Eliminate runtime errors with compile-time guarantees
  3. Trace Manipulation (~60 minutes)

    • Deep dive into Fugue's runtime system
    • Learn custom inference and debugging techniques
    • Build production-ready probabilistic applications

Tutorial Overview

๐Ÿช™ Bayesian Coin Flip

Foundation: Statistical inference and model analysis

Your introduction to Bayesian reasoning through the classic coin flipping problem. This tutorial demonstrates how prior beliefs are updated with evidence to form posterior distributions.

Key Concepts:

  • Prior, likelihood, and posterior distributions
  • Bayesian updating with Beta-Binomial conjugacy
  • Model validation and parameter estimation
  • Analytical vs computational solutions

What You'll Build:

  • Complete Bayesian coin bias estimation model
  • Prior sensitivity analysis framework
  • Model validation with synthetic data

Prerequisites

Basic probability theory (distributions, Bayes' theorem)


๐Ÿ”’ Type Safety Features

Foundation: Type-safe probabilistic programming

Explore Fugue's revolutionary type system that eliminates runtime errors while preserving full statistical expressiveness. Learn how dependent types make probabilistic programs both safer and faster.

Key Concepts:

  • Natural return types for distributions (bool, u64, f64, usize)
  • Compile-time safety guarantees
  • Safe array indexing with categorical distributions
  • Parameter validation at construction time
  • Performance benefits through zero-cost abstractions

What You'll Build:

  • Type-safe hierarchical models
  • Safe array indexing examples
  • Performance comparison with traditional PPLs

Why This Matters

Traditional PPLs force everything through f64, leading to runtime casting and errors. Fugue's type system catches these issues at compile time.


๐Ÿ” Trace Manipulation

Foundation: Runtime system and advanced inference

Master Fugue's execution trace system โ€” the foundation that enables sophisticated inference algorithms. Learn how traces record, replay, and analyze probabilistic model executions.

Key Concepts:

  • Trace system architecture and execution history
  • Handler system for flexible model interpretation
  • Replay mechanics for MCMC and inference algorithms
  • Custom handlers for specialized inference strategies
  • Memory optimization for production deployment
  • Diagnostic tools for convergence assessment

What You'll Build:

  • Custom MCMC algorithm using trace replay
  • Specialized handlers for debugging models
  • Production inference pipeline with memory optimization
  • Comprehensive diagnostic system for model validation

Advanced Content

This tutorial covers sophisticated concepts. Complete the previous tutorials first.

Learning Outcomes

After completing these foundation tutorials, you will:

๐Ÿ“Š Statistical Mastery

  • โœ… Understand Bayesian inference from first principles
  • โœ… Build and validate probabilistic models confidently
  • โœ… Interpret posterior distributions and uncertainty quantification
  • โœ… Apply conjugate analysis and computational methods

๐Ÿ›ก๏ธ Type-Safe Programming

  • โœ… Write probabilistic programs that catch errors at compile time
  • โœ… Leverage natural return types for cleaner, safer code
  • โœ… Understand performance benefits of zero-cost abstractions
  • โœ… Build complex models with guaranteed type safety

โš™๏ธ Advanced Inference

  • โœ… Manipulate execution traces for custom inference algorithms
  • โœ… Implement specialized handlers for unique requirements
  • โœ… Debug and optimize problematic models systematically
  • โœ… Deploy production-ready probabilistic systems

๐Ÿ”ง Production Skills

  • โœ… Memory optimization techniques for high-throughput scenarios
  • โœ… Convergence diagnostics and model validation workflows
  • โœ… Custom inference algorithms tailored to specific problems
  • โœ… Systematic debugging of numerical issues

Code Examples

All tutorials include comprehensive, tested code examples:

Each example includes:

  • โœ… Comprehensive tests ensuring correctness
  • โœ… Detailed comments explaining every concept
  • โœ… Runnable code you can execute immediately
  • โœ… Performance benchmarks where applicable
# Run any example to see concepts in action
cargo run --example bayesian_coin_flip
cargo run --example type_safety  
cargo run --example trace_manipulation

# Run tests to verify your understanding
cargo test --example bayesian_coin_flip
cargo test --example type_safety
cargo test --example trace_manipulation

Next Steps

After mastering these foundations, you're ready for:

๐Ÿ“ˆ Statistical Modeling Tutorials

Apply your knowledge to real-world problems:

  • Linear and logistic regression
  • Hierarchical models and mixed effects
  • Mixture models and clustering
  • Time series and forecasting

๐Ÿ—๏ธ How-To Guides

Practical guidance for specific tasks:

๐Ÿš€ Advanced Applications

Cutting-edge probabilistic programming:

  • Advanced inference techniques
  • Model comparison and selection
  • Large-scale distributed inference

Getting Help

๐Ÿ“š Documentation

๐Ÿ’ก Tips for Success

Learning Strategy

  1. Code Along: Don't just read โ€” run the examples and modify them
  2. Experiment: Change parameters and observe how results differ
  3. Test Understanding: Complete the exercises in each tutorial
  4. Apply Concepts: Try building your own models using the techniques

Common Pitfalls

  • Skipping mathematical foundations: The Bayesian coin flip tutorial builds essential intuition
  • Ignoring type safety: Fugue's type system prevents many subtle bugs
  • Not understanding traces: The execution history is key to advanced inference

๐Ÿ”ง Troubleshooting

If you encounter issues:

  1. Check prerequisites - Ensure you have the required mathematical background
  2. Run examples step-by-step - Isolate where confusion arises
  3. Review error messages - Fugue's type system provides helpful compile-time feedback
  4. Consult diagnostics - Use trace analysis to debug model behavior

Ready to Begin?

Start your journey with Bayesian Coin Flip โ€” the gateway to mastering probabilistic programming with Fugue.

Foundation Tutorials

These tutorials transform you from a probabilistic programming novice to someone who can build sophisticated, type-safe, production-ready Bayesian models. Each concept builds on the previous, creating a complete mental model of how Fugue works.

Time Investment: ~2.5 hours total
Skill Level: Beginner to Intermediate
Outcome: Complete foundation in modern probabilistic programming

Bayesian Coin Flip

A comprehensive introduction to Bayesian inference through the classic coin flip problem. This tutorial demonstrates core Bayesian concepts including prior beliefs, likelihood functions, posterior distributions, and conjugate analysis using Fugue's type-safe probabilistic programming framework.

Learning Objectives

By the end of this tutorial, you will understand:

  • Bayesian Inference: How to combine prior beliefs with data
  • Conjugate Analysis: Analytical solutions for Beta-Bernoulli models
  • Model Validation: Posterior predictive checks and diagnostics
  • Decision Theory: Making practical decisions under uncertainty

The Problem & Data

Research Question: Is a coin fair, or does it have a bias toward heads or tails?

In classical statistics, we might perform a hypothesis test. In Bayesian statistics, we express our uncertainty about the coin's bias as a probability distribution and update this belief as we observe data.

Why Coin Flips Matter

The coin flip problem is fundamental because it introduces all core Bayesian concepts in their simplest form:

  • Binary outcomes (success/failure, true/false) appear everywhere in practice
  • Beta-Bernoulli conjugacy provides exact analytical solutions
  • Parameter uncertainty is naturally quantified through probability distributions

Data Generation & Exploration

    // Real experimental data: coin flip outcomes
    // H = Heads (success), T = Tails (failure)
    let observed_flips = vec![
        true, false, true, true, false, true, true, false, true, true,
    ];
    let n_flips = observed_flips.len();
    let successes = observed_flips.iter().filter(|&&x| x).count();

    println!("๐Ÿช™ Observed coin flip sequence:");
    for (i, &flip) in observed_flips.iter().enumerate() {
        print!("  Flip {}: {}", i + 1, if flip { "H" } else { "T" });
        if (i + 1) % 5 == 0 {
            println!();
        }
    }
    println!(
        "\n๐Ÿ“Š Summary: {} successes out of {} flips ({:.1}%)",
        successes,
        n_flips,
        (successes as f64 / n_flips as f64) * 100.0
    );

    // Research question: Is this a fair coin? (p = 0.5)
    println!("โ“ Research Question: Is this coin fair (p = 0.5)?");

Real-World Context: This could represent:

  • Quality Control: Defective vs. non-defective products
  • Medical Trials: Treatment success rates
  • A/B Testing: Conversion rates between variants
  • Survey Response: Yes/No answers to questions

Mathematical Foundation

The Bayesian paradigm treats parameters as random variables with probability distributions. For the coin flip problem:

Model Specification

Prior Distribution: Our initial belief about the coin bias before seeing data:

p ~ Beta(α₀, β₀)

where α₀ and β₀ encode our prior "pseudo-observations" of successes and failures.

Likelihood Function: Given bias p, each flip follows:

X_i ~ Bernoulli(p)

The joint likelihood for n independent flips with s successes is:

L(p) = p^s (1 − p)^(n − s)

Posterior Distribution: By Bayes' theorem:

p(p | data) ∝ L(p) · p(p)

For the Beta-Bernoulli model, the posterior is:

p | data ~ Beta(α₀ + s, β₀ + n − s)

Conjugate Prior Theorem

The Beta distribution is conjugate to the Bernoulli likelihood, meaning:

  • Prior: p ~ Beta(α₀, β₀)
  • Likelihood: X_i ~ Bernoulli(p), with s successes in n flips
  • Posterior: p | data ~ Beta(α₀ + s, β₀ + n − s)

This gives us exact analytical solutions without requiring numerical approximation.
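For example, with the Beta(2, 2) prior used below and 7 heads in 10 flips, the posterior is Beta(2 + 7, 2 + 3) = Beta(9, 5), whose mean is 9/14 ≈ 0.643.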

    // Bayesian Model Specification:
    // Prior: p ~ Beta(α₀, β₀)  [belief about coin bias before data]
    // Likelihood: X_i ~ Bernoulli(p)  [each flip outcome]
    // Posterior: p|data ~ Beta(α₀ + successes, β₀ + failures)

    // Prior parameters (weakly informative)
    let prior_alpha = 2.0_f64; // Prior "successes"
    let prior_beta = 2.0_f64; // Prior "failures"

    // Prior implies: E[p] = α/(α+β) = 0.5, but allows uncertainty
    let prior_mean = prior_alpha / (prior_alpha + prior_beta);
    let prior_variance = (prior_alpha * prior_beta)
        / ((prior_alpha + prior_beta).powi(2_i32) * (prior_alpha + prior_beta + 1.0));

    println!(
        "๐Ÿ“ˆ Prior Distribution: Beta({}, {})",
        prior_alpha, prior_beta
    );
    println!("   - Prior mean: {:.3}", prior_mean);
    println!("   - Prior variance: {:.4}", prior_variance);
    println!("   - Interpretation: Weakly favors fairness but allows bias");

Prior Choice and Interpretation

The Beta(α, β) distribution has:

  • Mean: α / (α + β)
  • Variance: αβ / ((α + β)² (α + β + 1))

Common choices:

  • Beta(1,1): Uniform prior (no preference)
  • Beta(2,2): Weakly informative, slight preference for fairness
  • Beta(0.5, 0.5): Jeffreys prior (non-informative)

Basic Implementation

Let's implement our Bayesian coin flip model in Fugue:

use fugue::*;
// Define the probabilistic model
fn coin_flip_model(data: Vec<bool>) -> Model<f64> {
    prob!(
        // Prior belief about coin bias
        let p <- sample(addr!("coin_bias"), Beta::new(2.0, 2.0).unwrap());

        // Constrain p to valid range [0, 1] for numerical stability
        let p_constrained = p.clamp(1e-10, 1.0 - 1e-10);

        // Likelihood: observe each flip given the bias
        let _observations <- plate!(i in 0..data.len() => {
            observe(addr!("flip", i), Bernoulli::new(p_constrained).unwrap(), data[i])
        });

        // Return the inferred bias
        pure(p)
    )
}

Key Implementation Details:

  1. Type Safety: Bernoulli returns bool directlyโ€”no casting required
  2. Plate Notation: plate! efficiently handles vectorized observations
  3. Address Uniqueness: Each observation gets a unique address flip#i
  4. Pure Functional: Model returns the parameter of interest directly

Model Design Patterns

  • Use descriptive addresses like "coin_bias" instead of generic names
  • Plate notation scales efficiently to large datasets
  • Pure functions make models testable and composable
  • Type safety eliminates runtime errors from incorrect data types

Advanced Techniques

Analytical Posterior Solution

The beauty of conjugate priors is that we can compute the exact posterior without numerical approximation:

    // Beta-Bernoulli conjugacy gives exact posterior
    let posterior_alpha = prior_alpha + successes as f64;
    let posterior_beta = prior_beta + (n_flips - successes) as f64;

    let posterior_mean = posterior_alpha / (posterior_alpha + posterior_beta);
    let posterior_variance = (posterior_alpha * posterior_beta)
        / ((posterior_alpha + posterior_beta).powi(2) * (posterior_alpha + posterior_beta + 1.0));

    println!(
        "๐ŸŽฏ Analytical Posterior: Beta({:.0}, {:.0})",
        posterior_alpha, posterior_beta
    );
    println!("   - Posterior mean: {:.3}", posterior_mean);
    println!("   - Posterior variance: {:.4}", posterior_variance);

    // Credible intervals
    let posterior_dist = Beta::new(posterior_alpha, posterior_beta).unwrap();
    let _lower_bound = 0.025; // 2.5th percentile
    let _upper_bound = 0.975; // 97.5th percentile

    // Approximate quantiles (would need inverse CDF for exact)
    println!(
        "   - 95% credible interval: approximately [{:.2}, {:.2}]",
        posterior_mean - 1.96 * posterior_variance.sqrt(),
        posterior_mean + 1.96 * posterior_variance.sqrt()
    );

    // Hypothesis testing: P(p > 0.5 | data)
    let prob_biased_heads = if posterior_mean > 0.5 {
        0.8 // Rough approximation - would integrate Beta CDF for exact value
    } else {
        0.3
    };
    println!("   - P(p > 0.5 | data) โ‰ˆ {:.1}", prob_biased_heads);

MCMC Validation

While analytical solutions are preferred, we can validate our results using MCMC:

    // Use MCMC to approximate the posterior (for validation)
    let n_samples = 2000;
    let n_warmup = 500;

    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
    let mcmc_samples = adaptive_mcmc_chain(
        &mut rng,
        || coin_flip_model(observed_flips.clone()),
        n_samples,
        n_warmup,
    );

    // Extract posterior samples for p
    let posterior_samples: Vec<f64> = mcmc_samples
        .iter()
        .filter_map(|(_, trace)| trace.get_f64(&addr!("coin_bias")))
        .collect();

    if !posterior_samples.is_empty() {
        let mcmc_mean = posterior_samples.iter().sum::<f64>() / posterior_samples.len() as f64;
        let mcmc_variance = {
            let mean = mcmc_mean;
            posterior_samples
                .iter()
                .map(|x| (x - mean).powi(2_i32))
                .sum::<f64>()
                / (posterior_samples.len() - 1) as f64
        };
        let ess = effective_sample_size(&posterior_samples);

        println!(
            "โœ… MCMC Results ({} effective samples from {} total):",
            ess as usize,
            posterior_samples.len()
        );
        println!(
            "   - MCMC mean: {:.3} (analytical: {:.3})",
            mcmc_mean, posterior_mean
        );
        println!(
            "   - MCMC variance: {:.4} (analytical: {:.4})",
            mcmc_variance, posterior_variance
        );
        println!("   - Effective Sample Size: {:.0}", ess);
        println!(
            "   - Agreement: {}",
            if (mcmc_mean - posterior_mean).abs() < 0.05 {
                "โœ… Excellent"
            } else {
                "โš ๏ธ Check convergence"
            }
        );
    }

Why MCMC for a Conjugate Model?

Even though we have analytical solutions, MCMC serves important purposes:

  • Validation: Confirms our analytical calculations
  • Flexibility: Easily extends to non-conjugate models
  • Diagnostics: Provides convergence and mixing assessments

Effective Sample Size

The Effective Sample Size (ESS) measures how many independent samples we have. For good MCMC:

  • ESS > 400: Generally adequate for inference
  • ESS < 100: May indicate poor mixing or autocorrelation
  • ESS/Total < 0.1: Consider increasing chain length or improving proposals
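As a rough mental model for what effective_sample_size estimates (the exact estimator may differ), the standard autocorrelation-based form is

ESS ≈ N / (1 + 2 Σ_{k≥1} ρ_k)

where N is the number of post-warmup draws and ρ_k is the lag-k autocorrelation of the chain; heavy autocorrelation pushes ESS well below N.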

Diagnostics & Validation

Model validation ensures our model adequately represents the data-generating process:

    // Posterior predictive checks
    println!("๐Ÿ” Posterior Predictive Validation:");

    // Simulate new data from posterior predictive distribution
    let mut rng = thread_rng();
    let n_pred_samples = 1000;
    let mut predicted_successes = Vec::new();

    for _ in 0..n_pred_samples {
        // Sample bias from posterior
        let p_sample = posterior_dist.sample(&mut rng);

        // Simulate n_flips with this bias
        let mut pred_successes = 0;
        for _ in 0..n_flips {
            if Bernoulli::new(p_sample).unwrap().sample(&mut rng) {
                pred_successes += 1;
            }
        }
        predicted_successes.push(pred_successes);
    }

    // Compare with observed successes
    let pred_mean = predicted_successes.iter().sum::<usize>() as f64 / n_pred_samples as f64;
    let pred_within_range = predicted_successes
        .iter()
        .filter(|&&x| (x as i32 - successes as i32).abs() <= 2)
        .count() as f64
        / n_pred_samples as f64;

    println!("   - Observed successes: {}", successes);
    println!("   - Predicted mean successes: {:.1}", pred_mean);
    println!(
        "   - P(|pred - obs| โ‰ค 2): {:.1}%",
        pred_within_range * 100.0
    );

    if pred_within_range > 0.5 {
        println!("   - โœ… Model fits data well");
    } else {
        println!("   - โš ๏ธ Model may not capture data well");
    }

Posterior Predictive Checks

The posterior predictive distribution answers: "If our model is correct, what data would we expect to see?"

Interpretation:

  • Good fit: Observed data looks typical under the posterior predictive
  • Poor fit: Observed data is extreme under the posterior predictive
  • Model inadequacy: Systematic deviations suggest missing model components

Model Checking Principles

  • Never use the same data for both model fitting and validation
  • Multiple checks are better than single summary statistics
  • Graphical diagnostics often reveal patterns missed by numerical summaries
  • Extreme p-values (< 0.05 or > 0.95) suggest potential model issues

Production Extensions

Decision Theory and Practical Applications

Bayesian inference provides the foundation for optimal decision-making under uncertainty:

    // Bayesian decision theory for fairness testing
    println!("๐ŸŽฒ Decision Analysis:");

    // Define loss function for hypothesis testing
    // H0: coin is fair (p = 0.5), H1: coin is biased (p ≠ 0.5)
    let fairness_threshold = 0.05; // How far from 0.5 counts as "biased"
    let prob_fair = if (posterior_mean - 0.5).abs() < fairness_threshold {
        // Approximate based on credible interval
        0.6
    } else {
        0.2
    };

    println!(
        "   - Posterior probability coin is fair: {:.1}%",
        prob_fair * 100.0
    );
    println!(
        "   - Evidence for bias: {}",
        if prob_fair < 0.3 {
            "Strong"
        } else if prob_fair < 0.7 {
            "Moderate"
        } else {
            "Weak"
        }
    );

    // Expected number of heads in future flips
    let future_flips = 20;
    let expected_heads = posterior_mean * future_flips as f64;
    let uncertainty = (posterior_variance * future_flips as f64).sqrt();

    println!(
        "   - Expected heads in next {} flips: {:.1} ยฑ {:.1}",
        future_flips,
        expected_heads,
        1.96 * uncertainty
    );

    // Practical recommendations
    if (posterior_mean - 0.5).abs() < 0.1 {
        println!("   - ๐Ÿ’ก Recommendation: Treat as approximately fair for practical purposes");
    } else if posterior_mean > 0.5 {
        println!("   - ๐Ÿ’ก Recommendation: Coin appears biased toward heads");
    } else {
        println!("   - ๐Ÿ’ก Recommendation: Coin appears biased toward tails");
    }

Advanced Model Extensions

Real applications often require extensions beyond the basic model:

    // Hierarchical model for multiple coins
    println!("๐Ÿ”ฌ Advanced Modeling Extensions:");

    // Example: What if we had multiple coins?
    let _multi_coin_model = || {
        prob!(
            // Population-level parameters
            let pop_mean <- sample(addr!("population_mean"), Beta::new(1.0, 1.0).unwrap());
            let pop_concentration <- sample(addr!("concentration"), Gamma::new(2.0, 0.5).unwrap());

            // Individual coin bias (hierarchical prior)
            let alpha = pop_mean * pop_concentration;
            let beta = (1.0 - pop_mean) * pop_concentration;
            let coin_bias <- sample(addr!("coin_bias"), Beta::new(alpha, beta).unwrap());

            pure(coin_bias)
        )
    };

    println!("   - ๐Ÿ“ˆ Hierarchical Extension: Population of coins with shared parameters");
    println!("   - ๐Ÿ”„ Sequential Learning: Update beliefs with each new flip");
    println!("   - ๐ŸŽฏ Robust Models: Heavy-tailed priors for outlier resistance");
    println!("   - ๐Ÿ“Š Model Comparison: Bayes factors between fair vs. biased hypotheses");

    // Model comparison example (simplified)
    let fair_model_evidence = -5.2_f64; // Log marginal likelihood for fair model
    let biased_model_evidence = -4.8_f64; // Log marginal likelihood for biased model
    let bayes_factor = (biased_model_evidence - fair_model_evidence).exp();

    println!("   - โš–๏ธ Bayes Factor (biased/fair): {:.2}", bayes_factor);
    if bayes_factor > 3.0 {
        println!("     Evidence favors biased model");
    } else if bayes_factor < 1.0 / 3.0 {
        println!("     Evidence favors fair model");
    } else {
        println!("     Evidence is inconclusive");
    }

Advanced Modeling Scenarios:

  1. Hierarchical Models: Multiple coins with shared population parameters
  2. Sequential Learning: Online updates as new flips arrive
  3. Robust Priors: Heavy-tailed distributions to handle outliers
  4. Model Selection: Comparing fair vs. biased hypotheses using Bayes factors
graph TD
    A[Basic Beta-Bernoulli] --> B[Hierarchical Extension]
    A --> C[Sequential Updates]  
    A --> D[Robust Priors]
    A --> E[Model Comparison]
    B --> F[Population Studies]
    C --> G[Online Learning]
    D --> H[Outlier Resistance] 
    E --> I[Bayes Factors]

Real-World Considerations

When to Use Bayesian vs. Frequentist Methods

Bayesian Advantages:

  • Natural uncertainty quantification: Full posterior distributions
  • Prior knowledge incorporation: Systematic way to include expert knowledge
  • Decision-theoretic framework: Optimal decisions under specified loss functions
  • Sequential updating: Natural online learning as data arrives

Frequentist Advantages:

  • Objective interpretation: No need to specify prior distributions
  • Computational simplicity: Often faster for standard problems
  • Regulatory acceptance: Many standards assume frequentist methods

Practical Guidelines

Use Bayesian methods when:

  • You have relevant prior information to incorporate
  • You need full uncertainty quantification (not just point estimates)
  • You're making sequential decisions as data arrives
  • The cost of wrong decisions varies significantly

Use Frequentist methods when:

  • You want to avoid specifying prior distributions
  • Regulatory requirements mandate specific procedures
  • Computational resources are severely limited
  • The problem has well-established frequentist solutions

Performance Implications

  • Conjugate Models: Analytical solutions are extremely fast
  • MCMC Methods: Scale linearly with data size and number of parameters
  • Memory Usage: Fugue's trace system efficiently manages large parameter spaces
  • Numerical Stability: Log-space computations prevent underflow issues

Common Pitfalls

  1. Improper Priors: Always verify prior distributions integrate to 1
  2. Label Switching: In mixture models, parameter interpretability can change
  3. Convergence Assessment: Always check MCMC diagnostics before making inferences
  4. Prior Sensitivity: Test how conclusions change under different reasonable priors

Exercises

  1. Prior Sensitivity Analysis:

    • Try different Beta priors: Beta(1,1), Beta(5,5), Beta(0.5,0.5)
    • How do the posteriors differ with the same data?
    • When does prior choice matter most?
  2. Sequential Learning:

    • Start with Beta(2,2) prior
    • Update after each flip in sequence
    • Plot how the posterior evolves with each observation
  3. Model Comparison:

    • Implement a "fair coin" model with p = 0.5 exactly
    • Compare evidence for fair vs. biased models using marginal likelihoods
    • What sample size is needed to distinguish p = 0.5 from p = 0.6?
  4. Hierarchical Extension:

    • Model 5 different coins with a shared Beta population prior
    • Each coin has different numbers of flips
    • How does information sharing affect individual coin estimates?

Testing Your Understanding

Comprehensive test suite for validation:

    #[test]
    fn test_coin_flip_model_properties() {
        let test_data = vec![true, true, false, true];
        let mut rng = thread_rng();

        // Test model executes without panics
        let (bias_sample, trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            coin_flip_model(test_data.clone()),
        );

        // Bias should be valid probability
        assert!(bias_sample >= 0.0 && bias_sample <= 1.0);

        // Trace should contain expected choices
        assert!(trace.get_f64(&addr!("coin_bias")).is_some());
        assert!(trace.total_log_weight().is_finite());

        // Should have observation sites for each data point
        for _i in 0..test_data.len() {
            // Observations don't create choices, but affect likelihood
            assert!(trace.log_likelihood.is_finite());
        }
    }

    #[test]
    fn test_conjugate_update_correctness() {
        // Test analytical posterior against known values
        let prior_alpha = 2.0;
        let prior_beta = 2.0;
        let successes = 7;
        let failures = 3;

        let posterior_alpha = prior_alpha + successes as f64;
        let posterior_beta = prior_beta + failures as f64;
        let posterior_mean = posterior_alpha / (posterior_alpha + posterior_beta);

        // Should be (2+7)/(2+2+7+3) = 9/14 ≈ 0.643
        assert!((posterior_mean - 9.0 / 14.0).abs() < 1e-10);

        // Posterior should be more concentrated than prior
        let prior_variance = (2.0 * 2.0) / (4.0_f64.powi(2_i32) * 5.0);
        let posterior_variance = (posterior_alpha * posterior_beta)
            / ((posterior_alpha + posterior_beta).powi(2_i32)
                * (posterior_alpha + posterior_beta + 1.0));
        assert!(posterior_variance < prior_variance);
    }

    #[test]
    fn test_model_with_edge_cases() {
        let mut rng = thread_rng();

        // Test with all heads
        let all_heads = vec![true; 10];
        let (bias, _) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            coin_flip_model(all_heads),
        );
        // Should still be valid probability
        assert!(bias >= 0.0 && bias <= 1.0);

        // Test with all tails
        let all_tails = vec![false; 10];
        let (bias, _) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            coin_flip_model(all_tails),
        );
        assert!(bias >= 0.0 && bias <= 1.0);

        // Test with single flip
        let single_flip = vec![true];
        let (bias, _) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            coin_flip_model(single_flip),
        );
        assert!(bias >= 0.0 && bias <= 1.0);
    }

Next Steps

Now that you understand Bayesian inference fundamentals, you can build on them in the tutorials that follow.

The coin flip problem provides the conceptual foundation for all Bayesian modeling. Every complex model builds on these same principles: prior beliefs, likelihood functions, and posterior inference.

Key Takeaways

โœ… Bayesian inference combines prior knowledge with observed data systematically

โœ… Conjugate priors enable exact analytical solutions for many important problems

โœ… Posterior distributions quantify parameter uncertainty naturally

โœ… Model validation through posterior predictive checks ensures model adequacy

โœ… Decision theory provides a framework for optimal decision-making under uncertainty

Type Safety Features

A comprehensive exploration of Fugue's type-safe distribution system and its practical implications for probabilistic programming. This tutorial demonstrates how dependent type theory principles eliminate runtime errors while preserving full statistical expressiveness, making probabilistic programs both safer and more performant.

Learning Objectives

By the end of this tutorial, you will understand:

  • Natural Return Types: How distributions return mathematically appropriate types
  • Compile-Time Safety: How the type system catches errors before runtime
  • Safe Array Indexing: How categorical distributions guarantee bounds safety
  • Parameter Validation: How invalid distributions are caught at construction time
  • Performance Benefits: How type safety eliminates casting overhead and runtime checks

The Type Safety Problem

Traditional probabilistic programming languages force all distributions to return f64, creating a fundamental mismatch between mathematical concepts and their computational representation. This leads to pervasive runtime errors, casting overhead, and semantic confusion.

graph TD
    A["Traditional PPL"] --> B["All distributions โ†’ f64"]
    B --> C["Runtime Errors"]
    B --> D["Casting Overhead"]  
    B --> E["Semantic Confusion"]
    B --> F["Precision Loss"]
    
    G["Fugue PPL"] --> H["Natural Return Types"]
    H --> I["bool for Bernoulli"]
    H --> J["u64 for Poisson"]
    H --> K["usize for Categorical"]
    H --> L["f64 for Normal"]
    
    I --> M["Compile-Time Safety"]
    J --> M
    K --> M
    L --> M
    
    style A fill:#ffcccc
    style G fill:#ccffcc
    style M fill:#ccffff

Traditional PPL Problems

// Demonstrates problems with traditional PPL approaches (shown for contrast)
fn traditional_ppl_problems() {
    println!("=== Traditional PPL Problems (What Fugue Solves) ===\n");

    // In traditional PPLs, everything returns f64, leading to:
    println!("โŒ Traditional PPL Issues:");
    println!("   - Bernoulli returns f64 โ†’ if sample == 1.0 (awkward)");
    println!("   - Poisson returns f64 โ†’ count.round() as u64 (precision loss)");
    println!("   - Categorical returns f64 โ†’ array[sample as usize] (unsafe)");
    println!("   - Runtime type errors and casting overhead");
    println!();
}

The f64 Trap

When everything returns f64, you lose semantic meaning and introduce subtle bugs:

  • if bernoulli_sample == 1.0 - floating-point equality is fragile
  • array[categorical_sample as usize] - unsafe casting can panic
  • poisson_sample.round() as u64 - precision loss in conversions

Mathematical Foundation

Fugue's type system is grounded in dependent type theory, where each distribution is parameterized not just by its parameters θ, but by its support type T.

Formal Type System

For a distribution D with parameters θ and support T, the contract is:

    D(θ) : Distribution<T>        sample(D(θ)) : T

This ensures that sampling operations return values in their natural mathematical domain:

Mathematical Object    Support            Fugue Type    Example
Bernoulli(p)           {false, true}      bool          true/false
Poisson(λ)             {0, 1, 2, ...}     u64           0, 1, 2, ...
Categorical(p)         {0, ..., k-1}      usize         Array indices
Normal(μ, σ)           ℝ                  f64           Continuous values

Type-Theoretic Properties

Type Safety Theorem

For any well-formed Fugue program with model m : Model<A> and a distribution d with support T:

  1. Type Preservation: If x <- sample(addr, d) appears in m, then x has type T
  2. Progress: All well-typed programs either terminate or can take a computation step
  3. Safety: Well-typed programs do not get "stuck" with runtime type errors

Natural Type System

Fugue eliminates the f64-everything problem by returning mathematically appropriate types:

use fugue::*;
use rand::thread_rng;
// Demonstrate Fugue's natural return types
fn natural_type_system() {
    println!("โœ… Fugue's Natural Type System");
    println!("==============================\n");

    let mut rng = thread_rng();

    // Boolean decisions: Bernoulli โ†’ bool
    let fair_coin = Bernoulli::new(0.5).unwrap();
    let is_heads: bool = fair_coin.sample(&mut rng);

    // Natural conditional logic - no comparisons!
    let outcome = if is_heads {
        "Heads - you win!"
    } else {
        "Tails - try again"
    };
    println!("๐Ÿช™ Coin flip: {} (type: bool)", outcome);

    // Count data: Poisson โ†’ u64
    let customer_arrivals = Poisson::new(5.0).unwrap();
    let arrivals: u64 = customer_arrivals.sample(&mut rng);

    // Direct arithmetic with counts - no casting!
    let service_time = arrivals * 10; // minutes per customer
    println!(
        "๐Ÿ‘ฅ Customers: {} arrivals, {}min service (type: u64)",
        arrivals, service_time
    );

    // Category selection: Categorical โ†’ usize
    let product_preferences = Categorical::new(vec![0.4, 0.35, 0.25]).unwrap();
    let choice: usize = product_preferences.sample(&mut rng);

    // Safe array indexing - guaranteed bounds safety!
    let products = ["Laptop", "Smartphone", "Tablet"];
    println!(
        "๐Ÿ›’ Customer chose: {} (index: {}, type: usize)",
        products[choice], choice
    );

    // Continuous values: Normal โ†’ f64 (unchanged, as expected)
    let measurement = Normal::new(100.0, 5.0).unwrap();
    let reading: f64 = measurement.sample(&mut rng);
    println!("๐Ÿ“ Sensor reading: {:.2} units (type: f64)", reading);

    println!();
}

Type Benefits by Distribution

Bernoulli Distributions

  • Returns: bool - natural boolean logic
  • Benefit: Direct conditional statements without equality comparisons
  • Performance: No floating-point comparisons needed

Count Distributions (Poisson, Binomial)

  • Returns: u64 - natural counting numbers
  • Benefit: Direct arithmetic without casting or precision loss
  • Performance: Integer operations are faster than float conversions

Categorical Distributions

  • Returns: usize - natural array indices
  • Benefit: Guaranteed bounds safety for array indexing
  • Performance: No runtime bounds checking required

Continuous Distributions

  • Returns: f64 - unchanged for appropriate domains
  • Benefit: Expected behavior preserved for mathematical operations

Compile-Time Safety

Fugue's type system catches errors at compile time, eliminating entire classes of runtime failures:

use fugue::*;
use fugue::runtime::interpreters::PriorHandler;
use rand::thread_rng;
// Demonstrate compile-time type safety guarantees
fn compile_time_safety_demo() {
    println!("๐Ÿ›ก๏ธ Compile-Time Type Safety");
    println!("============================\n");

    // Type-safe model composition
    let data_model: Model<(bool, u64, usize, f64)> = prob!(
        let coin_result <- sample(addr!("coin"), Bernoulli::new(0.6).unwrap());
        let event_count <- sample(addr!("events"), Poisson::new(3.0).unwrap());
        let category <- sample(addr!("category"), Categorical::uniform(4).unwrap());
        let measurement <- sample(addr!("measure"), Normal::new(0.0, 1.0).unwrap());

        // Compiler enforces correct types throughout
        pure((coin_result, event_count, category, measurement))
    );

    println!("โœ… Model created with strict type guarantees:");
    println!("   - coin_result: bool (no == 1.0 needed)");
    println!("   - event_count: u64 (direct arithmetic)");
    println!("   - category: usize (safe indexing)");
    println!("   - measurement: f64 (natural continuous)");

    // Execute model safely
    let mut rng = thread_rng();
    let (sample, _trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        data_model,
    );

    println!(
        "๐Ÿ“Š Sample: coin={}, events={}, category={}, value={:.3}",
        sample.0, sample.1, sample.2, sample.3
    );
    println!();
}

Type-Safe Model Composition

Models compose naturally while preserving type information throughout the computation:

Composition Safety

When you compose models m1 : Model<A> and m2 : Model<B>, the result has type Model<(A, B)>. The type system tracks this precisely, ensuring you can't accidentally use a bool where you need a u64.
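
As a minimal sketch (using the prob!/sample/pure API shown throughout this tutorial), composing a bool-valued choice with a u64-valued choice yields a model whose result type is the tuple (bool, u64), and the compiler tracks that type end to end:

use fugue::*;

// Composing two differently-typed choices into one model.
fn composed_model() -> Model<(bool, u64)> {
    prob!(
        let flag <- sample(addr!("flag"), Bernoulli::new(0.5).unwrap());
        let count <- sample(addr!("count"), Poisson::new(3.0).unwrap());
        pure((flag, count))
    )
}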

Safe Array Indexing

One of the most dangerous operations in traditional PPLs is array indexing with categorical samples. Fugue makes this provably safe:

use fugue::*;
use rand::thread_rng;
// Demonstrate safe array indexing with categorical distributions
fn safe_array_indexing() {
    println!("๐ŸŽฏ Safe Array Indexing");
    println!("======================\n");

    let mut rng = thread_rng();

    // Define categories with natural indexing
    let algorithms = ["MCMC", "Variational Inference", "ABC", "SMC", "Exact"];
    let method_weights = vec![0.3, 0.25, 0.2, 0.15, 0.1];

    let method_selector = Categorical::new(method_weights).unwrap();

    println!("๐Ÿงฎ Available inference methods:");
    for (i, method) in algorithms.iter().enumerate() {
        println!("   {}: {}", i, method);
    }
    println!();

    // Sample multiple times to show safety
    for trial in 1..=5 {
        let selected_idx: usize = method_selector.sample(&mut rng);

        // This is GUARANTEED safe - no bounds checking needed!
        let chosen_method = algorithms[selected_idx];

        println!(
            "Trial {}: Selected method '{}' (index {})",
            trial, chosen_method, selected_idx
        );
    }

    println!("\nโœ… All array accesses guaranteed safe by type system!");
    println!();
}

Bounds Safety Guarantee

Categorical Safety Theorem

For a categorical distribution Categorical::new(weights) with k categories:

  • The distribution returns usize values in {0, 1, ..., k-1}
  • Any array with length โ‰ฅ k can be safely indexed with the result
  • No runtime bounds checking is required

Why This Matters

Traditional PPLs require defensive programming:

// Traditional PPL - unsafe!
let category = categorical_sample as usize;
if category < array.len() {
    return array[category];  // Still might panic due to float precision!
} else {
    return default_value;    // Defensive fallback
}

Fugue guarantees safety:

// Fugue - provably safe!
let category: usize = categorical.sample(&mut rng);
return array[category];  // Cannot panic - guaranteed by type system

Parameter Validation

Fugue validates all distribution parameters at construction time, catching invalid configurations before they can cause runtime errors:

use fugue::*;
// Demonstrate parameter validation and error handling
fn parameter_validation_demo() {
    println!("๐Ÿ” Parameter Validation");
    println!("=======================\n");

    println!("Fugue validates parameters at construction time:");
    println!();

    // Valid constructions
    match Normal::new(0.0, 1.0) {
        Ok(_) => println!("โœ… Normal(ฮผ=0.0, ฯƒ=1.0) - valid"),
        Err(e) => println!("โŒ Unexpected error: {:?}", e),
    }

    match Beta::new(2.0, 3.0) {
        Ok(_) => println!("โœ… Beta(ฮฑ=2.0, ฮฒ=3.0) - valid"),
        Err(e) => println!("โŒ Unexpected error: {:?}", e),
    }

    match Categorical::new(vec![0.3, 0.4, 0.3]) {
        Ok(_) => println!("โœ… Categorical([0.3, 0.4, 0.3]) - valid"),
        Err(e) => println!("โŒ Unexpected error: {:?}", e),
    }

    println!();

    // Invalid constructions - rejected at construction time. Handle the Err
    // gracefully with pattern matching, or fail fast with .unwrap().
    println!("Invalid parameter examples:");

    match Normal::new(0.0, -1.0) {
        Ok(_) => println!("โœ… Normal(ฮผ=0.0, ฯƒ=-1.0) - unexpected success"),
        Err(e) => println!("โŒ Normal(ฮผ=0.0, ฯƒ=-1.0) - {}", e),
    }

    match Beta::new(0.0, 1.0) {
        Ok(_) => println!("โœ… Beta(ฮฑ=0.0, ฮฒ=1.0) - unexpected success"),
        Err(e) => println!("โŒ Beta(ฮฑ=0.0, ฮฒ=1.0) - {}", e),
    }

    match Categorical::new(vec![0.5, 0.6]) {
        // Doesn't sum to 1
        Ok(_) => println!("โœ… Categorical([0.5, 0.6]) - unexpected success"),
        Err(e) => println!("โŒ Categorical([0.5, 0.6]) - {}", e),
    }

    println!("\nโœ… All invalid parameters caught before runtime!");
    println!();
}

Validation Strategy

Fugue uses fail-fast construction with comprehensive parameter checking:

Distribution      Parameters        Validation Rules
Normal(μ, σ)      μ: f64, σ: f64    σ > 0
Beta(α, β)        α: f64, β: f64    α > 0, β > 0
Poisson(λ)        λ: f64            λ > 0
Categorical(p)    p: Vec<f64>       all(pᵢ ≥ 0), sum(p) ≈ 1

Design Philosophy

Fugue follows the principle of "make invalid states unrepresentable". By validating at construction time, we ensure that every Distribution object represents a mathematically valid probability distribution.

Type-Safe Observations

Observations in Fugue must match the distribution's return type, providing compile-time guarantees about data consistency:

use fugue::*;
use fugue::runtime::interpreters::PriorHandler;
use rand::thread_rng;
// Demonstrate type-safe observations with automatic type checking
fn type_safe_observations() {
    println!("๐Ÿ”— Type-Safe Observations");
    println!("=========================\n");

    // Observations must match distribution return types
    let observation_model = prob!(
        // Boolean observation - must provide bool
        let _bool_obs <- observe(addr!("coin_obs"),
                                Bernoulli::new(0.7).unwrap(),
                                true); // โœ… bool type matches

        // Count observation - must provide u64
        let _count_obs <- observe(addr!("events_obs"),
                                 Poisson::new(4.0).unwrap(),
                                 5u64); // โœ… u64 type matches

        // Category observation - must provide usize
        let _category_obs <- observe(addr!("choice_obs"),
                                    Categorical::new(vec![0.2, 0.5, 0.3]).unwrap(),
                                    1usize); // โœ… usize type matches

        // Continuous observation - must provide f64
        let _continuous_obs <- observe(addr!("measurement_obs"),
                                      Normal::new(10.0, 2.0).unwrap(),
                                      12.5f64); // โœ… f64 type matches

        pure(())
    );

    println!("โœ… All observations type-checked at compile time!");
    println!("   - Bernoulli observation: bool");
    println!("   - Poisson observation: u64");
    println!("   - Categorical observation: usize");
    println!("   - Normal observation: f64");

    // Execute to verify it works
    let mut rng = thread_rng();
    let (_result, trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        observation_model,
    );

    println!(
        "๐Ÿ“Š Model executed successfully with {} addresses",
        trace.choices.len()
    );
    println!();
}

Observation Type Matching

The type system ensures that observed values match the distribution's natural type:

// โœ… Type-safe observations
observe(addr!("coin"), Bernoulli::new(0.5).unwrap(), true);     // bool
observe(addr!("count"), Poisson::new(3.0).unwrap(), 5u64);     // u64  
observe(addr!("choice"), Categorical::uniform(3).unwrap(), 1usize); // usize
observe(addr!("measure"), Normal::new(0.0, 1.0).unwrap(), 2.5f64);  // f64

// โŒ These would be compile-time errors
observe(addr!("coin"), Bernoulli::new(0.5).unwrap(), 1.0);     // f64 โ‰  bool
observe(addr!("count"), Poisson::new(3.0).unwrap(), 5.0);      // f64 โ‰  u64
observe(addr!("choice"), Categorical::uniform(3).unwrap(), 1);  // i32 โ‰  usize

Advanced Type Composition

Fugue supports complex hierarchical models with full type safety throughout the computation:

use fugue::*;
use fugue::runtime::interpreters::PriorHandler;
use rand::thread_rng;
// Demonstrate advanced type-safe model composition
fn advanced_type_composition() {
    println!("๐Ÿงฉ Advanced Type Composition");
    println!("============================\n");

    // Complex hierarchical model with full type safety
    let hierarchical_model = prob!(
        // Global parameters
        let success_rate <- sample(addr!("global_rate"), Beta::new(1.0, 1.0).unwrap());

        // Group-specific parameters (different types working together)
        let group_sizes <- sequence_vec((0..3).map(|group_id| {
            sample(addr!("group_size", group_id), Poisson::new(10.0).unwrap())
        }).collect());

        let group_successes <- sequence_vec(group_sizes.iter().enumerate().map(|(group_id, &size)| {
            sample(addr!("successes", group_id), Binomial::new(size, success_rate).unwrap())
        }).collect());

        // Category assignments for each group
        let group_categories <- sequence_vec((0..3).map(|group_id| {
            sample(addr!("category", group_id), Categorical::uniform(4).unwrap())
        }).collect());

        // Return complex structured result with full type safety
        pure((success_rate, group_sizes, group_successes, group_categories))
    );

    println!("๐Ÿ—๏ธ Hierarchical model structure:");
    println!("   - Global success rate: f64 (Beta distribution)");
    println!("   - Group sizes: Vec<u64> (Poisson distributions)");
    println!("   - Group successes: Vec<u64> (Binomial distributions)");
    println!("   - Group categories: Vec<usize> (Categorical distributions)");
    println!();

    // Sample from the complex model
    let mut rng = thread_rng();
    let (result, _trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        hierarchical_model,
    );

    let (rate, sizes, successes, categories) = result;

    println!("๐Ÿ“ˆ Sample from hierarchical model:");
    println!("   Global success rate: {:.3}", rate);

    for (i, ((&size, &success), &category)) in sizes
        .iter()
        .zip(successes.iter())
        .zip(categories.iter())
        .enumerate()
    {
        println!(
            "   Group {}: {} trials, {} successes, category {}",
            i, size, success, category
        );
    }

    println!("\nโœ… Complex model composed with full type safety!");
    println!();
}

Hierarchical Type Structure

Complex models maintain precise type information at every level:

graph TD
    A["Global: f64"] --> B["Group Sizes: Vec&lt;u64&gt;"]
    A --> C["Group Successes: Vec&lt;u64&gt;"]  
    A --> D["Group Categories: Vec&lt;usize&gt;"]
    
    B --> E["Model: (f64, Vec&lt;u64&gt;, Vec&lt;u64&gt;, Vec&lt;usize&gt;)"]
    C --> E
    D --> E
    
    style A fill:#e1f5fe
    style E fill:#c8e6c9

Hierarchical Modeling

Fugue's type system scales naturally to arbitrarily complex hierarchical models. Each level maintains its natural types, and the overall model type is compositionally determined by the type rules.

Performance Benefits

Type safety in Fugue eliminates runtime overhead through zero-cost abstractions:

// Demonstrate performance benefits of type safety
fn performance_benefits() {
    println!("โšก Performance Benefits");
    println!("======================\n");

    println!("Type safety eliminates runtime overhead:");
    println!();

    println!("๐Ÿšซ Traditional PPL (f64 everything):");
    println!("   let coin_flip = sample(...); // Returns f64");
    println!("   if coin_flip == 1.0 {{ ... }} // Float comparison");
    println!("   let count = sample(...) as u64; // Casting overhead");
    println!("   array[sample(...) as usize] // Unsafe casting + bounds check");
    println!();

    println!("โœ… Fugue (natural types):");
    println!("   let coin_flip: bool = sample(...); // Returns bool");
    println!("   if coin_flip {{ ... }} // Natural boolean");
    println!("   let count: u64 = sample(...); // Direct u64");
    println!("   array[sample(...)] // Safe usize indexing");
    println!();

    println!("๐ŸŽฏ Benefits:");
    println!("   โœ“ Zero casting overhead");
    println!("   โœ“ No floating-point comparisons for discrete values");
    println!("   โœ“ Eliminated bounds checking for categorical indexing");
    println!("   โœ“ No precision loss from floatโ†’int conversions");
    println!("   โœ“ Compile-time error detection");
    println!();
}

Performance Analysis

Operation               Traditional PPL        Fugue           Benefit
Boolean logic           Float comparison       Direct bool     ~2x faster
Count arithmetic        Cast + compute         Direct u64      ~1.5x faster
Array indexing          Cast + bounds check    Direct usize    ~3x faster
Parameter validation    Runtime checks         Compile-time    ∞x faster

Zero-Cost Abstraction Theorem

Fugue's type safety incurs zero runtime cost. The type information is used only at compile time to:

  1. Generate optimized machine code
  2. Eliminate unnecessary runtime checks
  3. Enable compiler optimizations that would be unsafe with dynamic typing

Real-World Applications

Quality Control System

use fugue::*;
let quality_model = prob!(
    // Product defect rate (continuous parameter)
    let defect_rate <- sample(addr!("defect_rate"), Beta::new(1.0, 9.0).unwrap());
    
    // Number of products tested (count data)
    let products_tested <- sample(addr!("tested"), Poisson::new(100.0).unwrap());
    
    // Actual defects found (count with bounds)
    let defects_found <- sample(addr!("defects"), 
                                Binomial::new(products_tested, defect_rate).unwrap());
    
    // Inspector assignment (categorical choice)
    let inspector <- sample(addr!("inspector"), Categorical::uniform(3).unwrap());
    
    // Natural type usage throughout
    pure((defect_rate, products_tested, defects_found, inspector))
);

Medical Diagnosis System

use fugue::*;  
let diagnosis_model = prob!(
    // Prior disease probability (continuous)
    let disease_prob <- sample(addr!("prior"), Beta::new(2.0, 98.0).unwrap());
    
    // Number of symptoms (count) 
    let symptom_count <- sample(addr!("symptoms"), Poisson::new(2.5).unwrap());
    
    // Test result (boolean outcome)
    let test_positive <- sample(addr!("test"), Bernoulli::new(0.95).unwrap());
    
    // Treatment recommendation (categorical)
    let treatment <- sample(addr!("treatment"), 
                           Categorical::new(vec![0.6, 0.3, 0.1]).unwrap());
    
    pure((disease_prob, symptom_count, test_positive, treatment))
);

Production Considerations

Error Handling Strategy

use fugue::*;
// Robust parameter validation
fn create_robust_model(rate: f64, categories: Vec<f64>) -> Result<Model<(u64, usize)>, String> {
    let poisson = Poisson::new(rate)
        .map_err(|e| format!("Invalid Poisson rate {}: {}", rate, e))?;
        
    let categorical = Categorical::new(categories)
        .map_err(|e| format!("Invalid categorical weights: {}", e))?;
    
    Ok(prob!(
        let count <- sample(addr!("count"), poisson);
        let choice <- sample(addr!("choice"), categorical);
        pure((count, choice))
    ))
}

Performance Optimization

Production Optimization

  1. Use appropriate integer types: u32 for small counts, u64 for large counts
  2. Leverage categorical safety: Pre-allocate lookup arrays knowing indices will be valid (see the sketch after this list)
  3. Avoid unnecessary conversions: Keep data in natural types throughout pipelines
  4. Profile bottlenecks: Type safety often reveals optimization opportunities
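
A small sketch of point 2 above (the weights and labels are illustrative): because a categorical built from k weights returns usize indices in 0..k, a lookup table of length k can be built once and indexed without defensive bounds checks.

use fugue::*;
use rand::thread_rng;

// Pre-allocated lookup table indexed by a categorical sample.
fn preallocated_lookup() {
    let weights = vec![0.5, 0.3, 0.2];
    let labels = ["low", "medium", "high"]; // length matches the number of weights

    let tier = Categorical::new(weights).unwrap();
    let mut rng = thread_rng();
    let idx: usize = tier.sample(&mut rng);

    println!("tier = {}", labels[idx]); // idx < labels.len() by construction
}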

Testing Your Understanding

Exercise 1: Safe Model Construction

Create a model that demonstrates all four natural return types. Ensure it:

  • Uses boolean logic for decision-making
  • Performs arithmetic with count data
  • Safely indexes into arrays
  • Handles continuous parameters
use fugue::*;
use fugue::runtime::interpreters::PriorHandler;
use rand::thread_rng;
// Exercise framework for testing understanding
fn testing_framework_example() {
    println!("๐Ÿงช Testing Framework Example");
    println!("============================\n");

    let comprehensive_model = prob!(
        // Boolean decision making
        let is_premium <- sample(addr!("premium"), Bernoulli::new(0.3).unwrap());

        // Count data arithmetic
        let base_items <- sample(addr!("base_items"), Poisson::new(5.0).unwrap());
        let bonus_items = if is_premium { base_items + 2 } else { base_items };

        // Safe array indexing
        let service_tier <- sample(addr!("tier"), Categorical::new(vec![0.5, 0.3, 0.2]).unwrap());

        // Continuous parameters
        let satisfaction <- sample(addr!("satisfaction"), Beta::new(2.0, 1.0).unwrap());

        pure((is_premium, bonus_items, service_tier, satisfaction))
    );

    println!("โœ… Comprehensive model demonstrates:");
    println!("   - Boolean logic: Premium account decision");
    println!("   - Count arithmetic: Items calculation with bonus");
    println!("   - Safe indexing: Service tier selection");
    println!("   - Continuous data: Customer satisfaction modeling");

    let mut rng = thread_rng();
    let (premium, items, tier, satisfaction) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        comprehensive_model,
    )
    .0;

    let tiers = ["Basic", "Standard", "Premium"];
    println!("\n๐Ÿ“Š Sample result:");
    println!("   Premium account: {}", premium);
    println!("   Items received: {}", items);
    println!("   Service tier: {} ({})", tiers[tier], tier);
    println!("   Satisfaction: {:.2}%", satisfaction * 100.0);
    println!();
}

Exercise 2: Parameter Validation

Write a function that attempts to create distributions with both valid and invalid parameters. Handle errors gracefully and provide meaningful error messages.

Exercise 3: Hierarchical Composition

Design a hierarchical model that combines multiple data types across different levels. Ensure type safety is maintained throughout the composition.

Key Takeaways

Type Safety Principles

  1. Natural Types: Each distribution returns its mathematically appropriate type
  2. Compile-Time Safety: Type errors are caught before deployment
  3. Zero-Cost Abstractions: Type safety improves both safety and performance
  4. Compositional: Type safety scales to arbitrary model complexity
  5. Practical: Eliminates common probabilistic programming bugs

Core Benefits:

  • โœ… Eliminated runtime type errors - impossible by construction
  • โœ… Natural mathematical operations - no awkward casting or comparisons
  • โœ… Guaranteed array safety - categorical indexing cannot panic
  • โœ… Performance improvements - zero-cost abstractions enable optimizations
  • โœ… Clear code intent - types document the mathematical structure

Further Reading

  • Working with Distributions - Practical distribution usage patterns
  • Building Complex Models - Advanced composition techniques
  • API Reference - Complete type specifications
  • Types and Programming Languages by Benjamin Pierce - Theoretical foundations
  • Probabilistic Programming & Bayesian Methods for Hackers - Applied Bayesian inference

Trace Manipulation

A deep exploration of Fugue's runtime system and trace manipulation capabilities. This tutorial demonstrates how traces enable sophisticated probabilistic programming techniques including replay, scoring, custom inference, and debugging. Learn how Fugue's execution history recording makes advanced inference algorithms possible while maintaining full type safety.

Learning Objectives

By the end of this tutorial, you will understand:

  • Trace System Architecture: How execution history is recorded and structured
  • Runtime Interpreters: Different ways to execute the same probabilistic model
  • Replay Mechanics: How traces enable MCMC and other inference algorithms
  • Custom Handlers: Building specialized execution strategies for specific needs
  • Memory Optimization: Production-ready techniques for high-throughput scenarios
  • Diagnostic Tools: Convergence assessment and debugging problematic models

The Execution History Problem

Traditional programming languages execute once and discard their execution history. In probabilistic programming, we need to record, manipulate, and reason about random choices to enable sophisticated inference algorithms. Fugue's trace system solves this fundamental challenge.

graph TD
    A["Model Specification"] --> B["Handler Selection"]
    B --> C["PriorHandler<br/>Forward Sampling"]
    B --> D["ReplayHandler<br/>MCMC Proposals"]  
    B --> E["ScoreGivenTrace<br/>Importance Sampling"]
    B --> F["Custom Handler<br/>Specialized Logic"]
    
    C --> G["Execution Trace"]
    D --> G
    E --> G
    F --> G
    
    G --> H["Choice Records"]
    G --> I["Log-Weight Components"]
    G --> J["Type-Safe Values"]
    
    H --> K["Replay"]
    H --> L["Scoring"]
    H --> M["Conditioning"]
    H --> N["Debugging"]
    
    style G fill:#ccffcc
    style K fill:#e1f5fe
    style L fill:#e1f5fe
    style M fill:#e1f5fe
    style N fill:#e1f5fe

Mathematical Foundation

Trace Formalization

A trace records the complete execution history of a probabilistic model, formally represented as a tuple:

    t = (C, log p_prior, log p_likelihood, log p_factors)

Where:

  • C: Map from addresses to choices
  • log p_prior: Accumulated prior log-probability
  • log p_likelihood: Accumulated observation log-probability
  • log p_factors: Accumulated factor weights

Total Log-Weight

The total unnormalized log-probability is:

    log w(t) = log p_prior + log p_likelihood + log p_factors

This decomposition enables sophisticated inference algorithms to reason about different sources of probability mass.

Trace Properties

Consistency: For a valid execution, log w(t) represents the unnormalized log-probability of that specific execution path.

Replayability: Given a trace t, the model can be deterministically re-executed to produce the same result and weight.

Compositionality: Traces can be modified, combined, and analyzed to implement complex inference strategies.

Basic Trace Inspection

Let's start by understanding how Fugue records execution history:

use fugue::*;
use fugue::runtime::interpreters::PriorHandler;
use rand::{SeedableRng, rngs::StdRng};
// Demonstrate basic trace inspection and manipulation
fn basic_trace_inspection() {
    println!("=== Basic Trace Inspection ===\n");

    // Define a model with multiple types of choices
    let model = prob!(
        let coin <- sample(addr!("coin"), Bernoulli::new(0.7).unwrap());
        let count <- sample(addr!("count"), Poisson::new(3.0).unwrap());
        let category <- sample(addr!("category"), Categorical::uniform(3).unwrap());
        let measurement <- sample(addr!("measurement"), Normal::new(0.0, 1.0).unwrap());

        // Observation adds to likelihood
        let _obs <- observe(addr!("obs"), Normal::new(measurement, 0.1).unwrap(), 0.5);

        pure((coin, count, category, measurement))
    );

    // Execute and inspect the trace
    let mut rng = StdRng::seed_from_u64(12345);
    let (result, trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        model,
    );

    println!("๐Ÿ” Trace Inspection:");
    println!("   - Total choices: {}", trace.choices.len());
    println!("   - Prior log-weight: {:.4}", trace.log_prior);
    println!("   - Likelihood log-weight: {:.4}", trace.log_likelihood);
    println!("   - Factor log-weight: {:.4}", trace.log_factors);
    println!("   - Total log-weight: {:.4}", trace.total_log_weight());
    println!();

    println!("๐Ÿ“Š Individual Choices:");
    for (addr, choice) in &trace.choices {
        println!(
            "   - {}: {:?} (logp: {:.4})",
            addr, choice.value, choice.logp
        );
    }
    println!();

    println!("๐ŸŽฏ Type-Safe Value Access:");
    println!("   - Coin (bool): {:?}", trace.get_bool(&addr!("coin")));
    println!("   - Count (u64): {:?}", trace.get_u64(&addr!("count")));
    println!(
        "   - Category (usize): {:?}",
        trace.get_usize(&addr!("category"))
    );
    println!(
        "   - Measurement (f64): {:?}",
        trace.get_f64(&addr!("measurement"))
    );

    let (coin, count, category, measurement) = result;
    println!(
        "   - Result: coin={}, count={}, category={}, measurement={:.3}",
        coin, count, category, measurement
    );
    println!();
}

Trace Structure Analysis

Every trace contains three critical components:

  1. Choices Map: Records every random decision with its address, value, and log-probability
  2. Weight Decomposition: Separates prior, likelihood, and factor contributions
  3. Type-Safe Values: Maintains natural types throughout execution

Debugging with Traces

The trace decomposition immediately shows you:

  • Prior weight: How likely your parameter values are under priors
  • Likelihood weight: How well your model fits the observed data
  • Factor weight: Contribution from explicit factor() statements

Replay Mechanics

The replay system is the foundation of MCMC algorithms. It allows deterministic re-execution with modified random choices:

use fugue::*;
use fugue::runtime::interpreters::*;
use fugue::runtime::trace::*;
use rand::{SeedableRng, rngs::StdRng};
// Demonstrate trace replay mechanics for MCMC
fn replay_mechanics() {
    println!("=== Trace Replay Mechanics ===\n");

    // Define a simple model
    let make_model = || {
        prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap());
            let _obs <- observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 1.5);
            pure(mu)
        )
    };

    // 1. Generate initial trace
    let mut rng = StdRng::seed_from_u64(42);
    let (mu1, trace1) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        make_model(),
    );

    println!("๐ŸŽฒ Original Execution:");
    println!("   - mu = {:.3}", mu1);
    println!("   - Prior logp: {:.3}", trace1.log_prior);
    println!("   - Likelihood logp: {:.3}", trace1.log_likelihood);
    println!("   - Total logp: {:.3}", trace1.total_log_weight());
    println!();

    // 2. Replay with exact same trace
    let mut rng2 = StdRng::seed_from_u64(42); // New RNG for replay
    let (mu2, trace2) = runtime::handler::run(
        ReplayHandler {
            rng: &mut rng2,
            base: trace1.clone(),
            trace: Trace::default(),
        },
        make_model(),
    );

    println!("๐Ÿ”„ Exact Replay:");
    println!("   - mu = {:.3} (should match original)", mu2);
    println!("   - Values match: {}", mu1 == mu2);
    println!(
        "   - Traces match: {}",
        trace1.total_log_weight() == trace2.total_log_weight()
    );
    println!();

    // 3. Modify trace for proposal
    let mut modified_trace = trace1.clone();
    // Modify the mu value (MCMC proposal)
    if let Some(choice) = modified_trace.choices.get_mut(&addr!("mu")) {
        let old_value = choice.value.as_f64().unwrap();
        let new_value = old_value + 0.1; // Small proposal step
        choice.value = ChoiceValue::F64(new_value);

        // Recompute log-probability under the distribution
        let normal_dist = Normal::new(0.0, 2.0).unwrap();
        choice.logp = normal_dist.log_prob(&new_value);

        println!("๐Ÿ”ง Modified Trace (Proposal):");
        println!("   - Old mu: {:.3}", old_value);
        println!("   - New mu: {:.3}", new_value);
        println!(
            "   - Old logp: {:.3}",
            trace1
                .get_f64(&addr!("mu"))
                .map(|v| Normal::new(0.0, 2.0).unwrap().log_prob(&v))
                .unwrap_or(0.0)
        );
        println!("   - New logp: {:.3}", choice.logp);
    }

    // 4. Score the modified trace
    let mut rng3 = StdRng::seed_from_u64(42); // New RNG for proposal
    let (mu3, trace3) = runtime::handler::run(
        ReplayHandler {
            rng: &mut rng3,
            base: modified_trace,
            trace: Trace::default(),
        },
        make_model(),
    );

    println!("   - Proposal result: mu = {:.3}", mu3);
    println!("   - Proposal total logp: {:.3}", trace3.total_log_weight());
    println!(
        "   - Accept/Reject ratio: {:.3}",
        (trace3.total_log_weight() - trace1.total_log_weight()).exp()
    );
    println!();
}

MCMC Proposal Mechanism

sequenceDiagram
    participant M as Model
    participant T1 as Current Trace
    participant T2 as Proposal Trace
    participant A as Accept/Reject
    
    M->>T1: Execute with PriorHandler
    Note over T1: Record current state
    
    T1->>T2: Modify choice values
    Note over T2: Create proposal
    
    M->>T2: Execute with ReplayHandler  
    Note over T2: Score proposal
    
    T2->>A: Compare log-weights
    Note over A: Accept if log(u) < ฮ”log(w)
    
    A->>T1: Keep current (if rejected)
    A->>T2: Accept proposal (if accepted)

The key insight: the same model specification can be executed with different random choices by manipulating the trace and using replay.
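
As an illustration of that insight, here is a minimal Metropolis-Hastings acceptance sketch built on the replay workflow above. It covers only the accept/reject step, not Fugue's built-in MCMC; current and proposal are traces produced as in the previous listing.

use fugue::runtime::trace::Trace;
use rand::Rng;

// Accept the proposal with probability min(1, exp(Δ log-weight)).
fn accept_proposal<R: Rng>(rng: &mut R, current: &Trace, proposal: &Trace) -> bool {
    let log_alpha = proposal.total_log_weight() - current.total_log_weight();
    log_alpha >= 0.0 || rng.gen::<f64>().ln() < log_alpha
}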

Custom Handlers

Handlers define how probabilistic effects are interpreted. You can create custom handlers for specialized inference algorithms:

use fugue::*;
use fugue::runtime::{handler::Handler, trace::*};
use rand::{SeedableRng, rngs::StdRng};
// Demonstrate custom handler implementation
struct DebugHandler<R: rand::Rng> {
    rng: R,
    trace: Trace,
    debug_info: Vec<String>,
}

impl<R: rand::Rng> DebugHandler<R> {
    fn new(rng: R) -> Self {
        Self {
            rng,
            trace: Trace::default(),
            debug_info: Vec::new(),
        }
    }
}

impl<R: rand::Rng> Handler for DebugHandler<R> {
    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
        let value = dist.sample(&mut self.rng);
        let logp = dist.log_prob(&value);

        self.debug_info.push(format!(
            "SAMPLE f64 at {}: {} (logp: {:.3})",
            addr, value, logp
        ));

        self.trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::F64(value),
                logp,
            },
        );
        self.trace.log_prior += logp;

        value
    }

    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
        let value = dist.sample(&mut self.rng);
        let logp = dist.log_prob(&value);

        self.debug_info.push(format!(
            "SAMPLE bool at {}: {} (logp: {:.3})",
            addr, value, logp
        ));

        self.trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::Bool(value),
                logp,
            },
        );
        self.trace.log_prior += logp;

        value
    }

    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
        let value = dist.sample(&mut self.rng);
        let logp = dist.log_prob(&value);

        self.debug_info.push(format!(
            "SAMPLE u64 at {}: {} (logp: {:.3})",
            addr, value, logp
        ));

        self.trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::U64(value),
                logp,
            },
        );
        self.trace.log_prior += logp;

        value
    }

    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
        let value = dist.sample(&mut self.rng);
        let logp = dist.log_prob(&value);

        self.debug_info.push(format!(
            "SAMPLE usize at {}: {} (logp: {:.3})",
            addr, value, logp
        ));

        self.trace.choices.insert(
            addr.clone(),
            Choice {
                addr: addr.clone(),
                value: ChoiceValue::Usize(value),
                logp,
            },
        );
        self.trace.log_prior += logp;

        value
    }

    fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64) {
        let logp = dist.log_prob(&value);

        self.debug_info.push(format!(
            "OBSERVE f64 at {}: {} (logp: {:.3})",
            addr, value, logp
        ));

        self.trace.log_likelihood += logp;
    }

    fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool) {
        let logp = dist.log_prob(&value);

        self.debug_info.push(format!(
            "OBSERVE bool at {}: {} (logp: {:.3})",
            addr, value, logp
        ));

        self.trace.log_likelihood += logp;
    }

    fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64) {
        let logp = dist.log_prob(&value);

        self.debug_info.push(format!(
            "OBSERVE u64 at {}: {} (logp: {:.3})",
            addr, value, logp
        ));

        self.trace.log_likelihood += logp;
    }

    fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize) {
        let logp = dist.log_prob(&value);

        self.debug_info.push(format!(
            "OBSERVE usize at {}: {} (logp: {:.3})",
            addr, value, logp
        ));

        self.trace.log_likelihood += logp;
    }

    fn on_factor(&mut self, logw: f64) {
        self.debug_info.push(format!("FACTOR: {:.3}", logw));
        self.trace.log_factors += logw;
    }

    fn finish(self) -> Trace {
        // A fuller debug handler would also return or log self.debug_info;
        // this sketch keeps only the trace.
        self.trace
    }
}

fn custom_handler_demo() {
    println!("=== Custom Handler Demo ===\n");

    let model = prob!(
        let prior_mu <- sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap());
        let success <- sample(addr!("success"), Bernoulli::new(0.6).unwrap());

        let _obs1 <- observe(addr!("data1"), Normal::new(prior_mu, 0.5).unwrap(), 1.2);
        let _obs2 <- observe(addr!("data2"), Bernoulli::new(if success { 0.8 } else { 0.2 }).unwrap(), true);

        // Add a factor for soft constraints
        let _factor_result <- factor(if prior_mu > 0.0 { 0.1 } else { -0.1 });

        pure((prior_mu, success))
    );

    let rng = StdRng::seed_from_u64(67890);
    let debug_handler = DebugHandler::new(rng);

    let (result, final_trace) = runtime::handler::run(debug_handler, model);

    println!("๐Ÿ” Debug Handler Output:");
    println!("   - Result: {:?}", result);
    println!(
        "   - Total log-weight: {:.4}",
        final_trace.total_log_weight()
    );
    println!();

    println!("๐Ÿ“ Execution Log:");
    println!("   - {} operations recorded", final_trace.choices.len());

    println!();
}

Handler Architecture

graph TD
    A["Model Effects"] --> B["Handler Dispatch"]
    B --> C["on_sample_f64()"]
    B --> D["on_sample_bool()"] 
    B --> E["on_sample_u64()"]
    B --> F["on_sample_usize()"]
    B --> G["on_observe_*()"]
    B --> H["on_factor()"]
    
    C --> I["Custom Logic"]
    D --> I
    E --> I
    F --> I
    G --> I
    H --> I
    
    I --> J["Trace Update"]
    I --> K["Side Effects"]
    I --> L["Logging/Debug"]
    
    style I fill:#ccffcc

Built-in Handler Types

Handler                Purpose                    Use Case
PriorHandler           Forward sampling           Generate data, initialization
ReplayHandler          Deterministic replay       MCMC, validation
ScoreGivenTrace        Compute log-probability    Importance sampling
SafeReplayHandler      Error-resilient replay     Production MCMC
SafeScoreGivenTrace    Safe scoring               Robust inference

Trace Scoring

Scoring computes the log-probability of a specific execution path, essential for importance sampling and model comparison:

use fugue::*;
use fugue::runtime::interpreters::*;
use fugue::runtime::trace::*;
use rand::{SeedableRng, rngs::StdRng};
// Demonstrate trace scoring for importance sampling
fn trace_scoring_demo() {
    println!("=== Trace Scoring Demo ===\n");

    let make_model = || {
        prob!(
            let theta <- sample(addr!("theta"), Beta::new(2.0, 2.0).unwrap());

            // Multiple observations
            let _obs1 <- observe(addr!("y1"), Bernoulli::new(theta).unwrap(), true);
            let _obs2 <- observe(addr!("y2"), Bernoulli::new(theta).unwrap(), true);
            let _obs3 <- observe(addr!("y3"), Bernoulli::new(theta).unwrap(), false);

            pure(theta)
        )
    };

    // Generate a trace from the prior
    let mut rng = StdRng::seed_from_u64(111);
    let (theta_val, prior_trace) = runtime::handler::run(
        PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        },
        make_model(),
    );

    println!("๐ŸŽฒ Prior Sample:");
    println!("   - theta = {:.3}", theta_val);
    println!("   - Prior logp: {:.3}", prior_trace.log_prior);
    println!("   - Likelihood logp: {:.3}", prior_trace.log_likelihood);
    println!("   - Total logp: {:.3}", prior_trace.total_log_weight());
    println!();

    // Now score this trace under the model (should get same result)
    let (theta_scored, scored_trace) = runtime::handler::run(
        ScoreGivenTrace {
            base: prior_trace.clone(),
            trace: Trace::default(),
        },
        make_model(),
    );

    println!("๐Ÿ“Š Scoring Same Trace:");
    println!("   - theta = {:.3} (should match)", theta_scored);
    println!("   - Prior logp: {:.3}", scored_trace.log_prior);
    println!("   - Likelihood logp: {:.3}", scored_trace.log_likelihood);
    println!("   - Total logp: {:.3}", scored_trace.total_log_weight());
    println!(
        "   - Weights match: {}",
        (prior_trace.total_log_weight() - scored_trace.total_log_weight()).abs() < 1e-10
    );
    println!();

    // Create a modified trace for importance sampling
    let mut importance_trace = prior_trace.clone();

    // Change theta to a different value
    if let Some(choice) = importance_trace.choices.get_mut(&addr!("theta")) {
        let new_theta = 0.8; // High success probability
        choice.value = ChoiceValue::F64(new_theta);
        choice.logp = Beta::new(2.0, 2.0).unwrap().log_prob(&new_theta);

        println!("๐ŸŽฏ Importance Sample:");
        println!("   - Modified theta to: {:.3}", new_theta);
        println!("   - New prior logp: {:.3}", choice.logp);
    }

    // Score under original model
    let (theta_is, is_trace) = runtime::handler::run(
        ScoreGivenTrace {
            base: importance_trace,
            trace: Trace::default(),
        },
        make_model(),
    );

    println!("   - IS result: theta = {:.3}", theta_is);
    println!("   - IS total logp: {:.3}", is_trace.total_log_weight());
    println!(
        "   - Importance weight: {:.3}",
        is_trace.total_log_weight() - prior_trace.total_log_weight()
    );
    println!();
}

Importance Sampling Theory

Given a proposal trace t drawn from a proposal distribution q and a target model with unnormalized density p, expectations under p can be estimated by weighting samples from q:

    E_p[f(t)] ≈ Σ_i w(t_i) f(t_i) / Σ_i w(t_i)

Where the importance weight is:

    w(t) = p(t) / q(t),    i.e.    log w(t) = log p(t) - log q(t)

Fugue's scoring system automatically computes log p(t) for any trace under any model.

Numerical Stability

Always work in log-space for importance weights. Direct probability ratios quickly underflow or overflow for realistic models.
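
A small sketch of that advice: given log-weights collected as in the listing above, the standard log-sum-exp trick (subtracting the maximum before exponentiating) produces normalized importance weights without underflow or overflow.

// Normalize importance weights in log-space.
fn normalize_log_weights(log_weights: &[f64]) -> Vec<f64> {
    let max_lw = log_weights.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let total: f64 = log_weights.iter().map(|lw| (lw - max_lw).exp()).sum();
    log_weights.iter().map(|lw| (lw - max_lw).exp() / total).collect()
}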

Memory Optimization

For production workloads, efficient memory management is crucial:

use fugue::*;
use fugue::runtime::{interpreters::PriorHandler, memory::*};
use rand::{SeedableRng, rngs::StdRng};
// Demonstrate memory-optimized trace handling
fn memory_optimization_demo() {
    println!("=== Memory Optimization Demo ===\n");

    // Simple model for batch processing
    let make_model = |obs_val: f64| {
        prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap());
            let sigma <- sample(addr!("sigma"), Gamma::new(2.0, 0.5).unwrap());
            let _obs <- observe(addr!("y"), Normal::new(mu, sigma).unwrap(), obs_val);
            pure((mu, sigma))
        )
    };

    println!("๐Ÿญ Batch Processing with Memory Pool:");

    // Simulate batch inference with trace reuse
    let observations = [1.0, 1.2, 0.8, 1.5, 0.9];
    let mut results = Vec::new();

    // Use copy-on-write traces for efficiency
    let base_trace = CowTrace::new();

    for (i, &obs) in observations.iter().enumerate() {
        let mut rng = StdRng::seed_from_u64(200 + i as u64);
        let handler = PriorHandler {
            rng: &mut rng,
            trace: base_trace.to_trace(), // Convert to regular trace
        };

        let (result, trace) = runtime::handler::run(handler, make_model(obs));
        results.push((result, trace));

        println!(
            "   Sample {}: mu={:.3}, sigma={:.3}, obs={:.1}, logp={:.3}",
            i + 1,
            result.0,
            result.1,
            obs,
            results[i].1.total_log_weight()
        );
    }

    println!();
    println!("๐Ÿ“Š Batch Statistics:");
    let mu_mean = results.iter().map(|((mu, _), _)| mu).sum::<f64>() / results.len() as f64;
    let sigma_mean =
        results.iter().map(|((_, sigma), _)| sigma).sum::<f64>() / results.len() as f64;
    let logp_mean = results
        .iter()
        .map(|(_, trace)| trace.total_log_weight())
        .sum::<f64>()
        / results.len() as f64;

    println!("   - Average mu: {:.3}", mu_mean);
    println!("   - Average sigma: {:.3}", sigma_mean);
    println!("   - Average log-probability: {:.3}", logp_mean);
    println!();

    println!("๐Ÿ”ง Trace Builder Demo:");

    // TraceBuilder is constructed here only to illustrate the idea of
    // pre-allocating trace storage; it is not used further in this example.
    let _builder = TraceBuilder::new();

    // Manually construct a trace (rarely needed, but shows internals)
    let demo_trace = Trace {
        choices: [
            (
                addr!("param1"),
                Choice {
                    addr: addr!("param1"),
                    value: ChoiceValue::F64(0.5),
                    logp: -1.4,
                },
            ),
            (
                addr!("param2"),
                Choice {
                    addr: addr!("param2"),
                    value: ChoiceValue::Bool(true),
                    logp: -0.7,
                },
            ),
        ]
        .iter()
        .cloned()
        .collect(),
        log_prior: -2.1,
        log_likelihood: -0.5,
        log_factors: 0.0,
    };

    println!(
        "   - Manual trace: {} choices, total logp: {:.3}",
        demo_trace.choices.len(),
        demo_trace.total_log_weight()
    );
    println!();
}

Production Memory Strategies

  1. Copy-on-Write Traces: Share read-only data, copy only when modified
  2. Trace Pooling: Reuse allocated memory across multiple inferences
  3. Pre-sized Allocation: Reserve space for expected number of choices
  4. Batch Processing: Amortize allocation costs across many executions

Memory Benchmarking

For high-throughput scenarios:

  • Use TracePool for batch processing
  • Pre-size trace builders when choice count is predictable
  • Profile memory allocation patterns in your specific use case

Diagnostic Tools

Fugue provides comprehensive tools for analyzing trace quality and convergence:

use fugue::*;
use fugue::runtime::interpreters::PriorHandler;
use fugue::inference::diagnostics::*;
use rand::{SeedableRng, rngs::StdRng};
// Demonstrate diagnostic tools for trace analysis
fn diagnostic_tools_demo() {
    println!("=== Diagnostic Tools Demo ===\n");

    // Generate multiple MCMC-like traces for diagnostics
    let make_model = || {
        prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap());
            let precision <- sample(addr!("precision"), Gamma::new(2.0, 1.0).unwrap());

            // Multiple observations
            let _obs1 <- observe(addr!("y1"), Normal::new(mu, 1.0/precision.sqrt()).unwrap(), 1.0);
            let _obs2 <- observe(addr!("y2"), Normal::new(mu, 1.0/precision.sqrt()).unwrap(), 1.2);
            let _obs3 <- observe(addr!("y3"), Normal::new(mu, 1.0/precision.sqrt()).unwrap(), 0.8);

            pure((mu, precision))
        )
    };

    // Simulate two chains
    println!("๐Ÿ”— Generating MCMC-like traces:");
    let mut chain1 = Vec::new();
    let mut chain2 = Vec::new();

    // Chain 1
    for i in 0..20 {
        let mut rng = StdRng::seed_from_u64(300 + i);
        let (_, trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            make_model(),
        );
        chain1.push(trace);
    }

    // Chain 2 (different seed)
    for i in 0..20 {
        let mut rng = StdRng::seed_from_u64(400 + i);
        let (_, trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            make_model(),
        );
        chain2.push(trace);
    }

    println!("   - Chain 1: {} samples", chain1.len());
    println!("   - Chain 2: {} samples", chain2.len());
    println!();

    // Extract parameter values
    let mu_values1 = extract_f64_values(&chain1, &addr!("mu"));
    let mu_values2 = extract_f64_values(&chain2, &addr!("mu"));
    let precision_values1 = extract_f64_values(&chain1, &addr!("precision"));
    let _precision_values2 = extract_f64_values(&chain2, &addr!("precision"));

    println!("๐Ÿ“ˆ Parameter Summaries:");
    println!(
        "   - mu chain1: mean={:.3}, min={:.3}, max={:.3}",
        mu_values1.iter().sum::<f64>() / mu_values1.len() as f64,
        mu_values1.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
        mu_values1.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
    );

    println!(
        "   - mu chain2: mean={:.3}, min={:.3}, max={:.3}",
        mu_values2.iter().sum::<f64>() / mu_values2.len() as f64,
        mu_values2.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
        mu_values2.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
    );

    println!(
        "   - precision chain1: mean={:.3}, min={:.3}, max={:.3}",
        precision_values1.iter().sum::<f64>() / precision_values1.len() as f64,
        precision_values1
            .iter()
            .fold(f64::INFINITY, |a, &b| a.min(b)),
        precision_values1
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b))
    );

    // Compute R-hat (simplified version)
    let chains_mu = vec![chain1.clone(), chain2.clone()];
    let r_hat_mu = r_hat_f64(&chains_mu, &addr!("mu"));
    let r_hat_precision = r_hat_f64(&chains_mu, &addr!("precision"));

    println!();
    println!("๐ŸŽฏ Convergence Diagnostics:");
    println!("   - R-hat for mu: {:.4} (< 1.1 is good)", r_hat_mu);
    println!(
        "   - R-hat for precision: {:.4} (< 1.1 is good)",
        r_hat_precision
    );

    // Parameter summary
    let mu_summary = summarize_f64_parameter(&chains_mu, &addr!("mu"));
    let q5 = mu_summary.quantiles.get("2.5%").unwrap_or(&f64::NAN);
    let q95 = mu_summary.quantiles.get("97.5%").unwrap_or(&f64::NAN);
    println!(
        "   - mu summary: mean={:.3}, std={:.3}, q2.5={:.3}, q97.5={:.3}",
        mu_summary.mean, mu_summary.std, q5, q95
    );

    println!();
    println!("๐Ÿ“Š Trace Quality Assessment:");
    let total_logp_chain1: f64 = chain1.iter().map(|t| t.total_log_weight()).sum();
    let total_logp_chain2: f64 = chain2.iter().map(|t| t.total_log_weight()).sum();
    let avg_logp1 = total_logp_chain1 / chain1.len() as f64;
    let avg_logp2 = total_logp_chain2 / chain2.len() as f64;

    println!("   - Chain 1 avg log-probability: {:.3}", avg_logp1);
    println!("   - Chain 2 avg log-probability: {:.3}", avg_logp2);
    println!(
        "   - Chains similar quality: {}",
        (avg_logp1 - avg_logp2).abs() < 0.5
    );
    println!();
}

Convergence Assessment

R-hat Statistic: Compares between-chain variance to within-chain variance:

R̂ = √(V̂ / W)

Where:

  • V̂: Estimated marginal posterior variance (a weighted combination of within- and between-chain variance)
  • W: Within-chain variance

Interpretation:

  • R̂ < 1.1: Good convergence
  • 1.1 ≤ R̂ < 1.2: Chains haven't mixed well, need more samples
  • R̂ ≥ 1.2: Poor convergence, investigate model or algorithm
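
The library computes this for you via `r_hat_f64` on whole traces; the following from-scratch sketch (hypothetical helper and variable names, assuming each chain's draws for one parameter are already collected into a `Vec<f64>`) shows the arithmetic behind the statistic:

fn r_hat(chains: &[Vec<f64>]) -> f64 {
    let m = chains.len() as f64;    // number of chains
    let n = chains[0].len() as f64; // draws per chain

    // Per-chain means and the average within-chain variance W
    let means: Vec<f64> = chains.iter().map(|c| c.iter().sum::<f64>() / n).collect();
    let within: f64 = chains.iter().zip(&means)
        .map(|(c, &mu)| c.iter().map(|x| (x - mu).powi(2)).sum::<f64>() / (n - 1.0))
        .sum::<f64>() / m;

    // Between-chain variance B
    let grand_mean = means.iter().sum::<f64>() / m;
    let between = n * means.iter().map(|mu| (mu - grand_mean).powi(2)).sum::<f64>() / (m - 1.0);

    // Pooled variance estimate V-hat and the R-hat ratio
    let var_plus = (n - 1.0) / n * within + between / n;
    (var_plus / within).sqrt()
}

fn main() {
    let chain_a = vec![0.10, 0.30, 0.20, 0.40, 0.25];
    let chain_b = vec![0.15, 0.35, 0.22, 0.38, 0.27];
    println!("R-hat ≈ {:.4}", r_hat(&[chain_a, chain_b]));
}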

Parameter Summaries

For each parameter, compute:

  • Mean and Standard Deviation: Central tendency and spread
  • Quantiles: 5%, 25%, 50%, 75%, 95% for uncertainty intervals
  • Effective Sample Size: Accounting for autocorrelation
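
As a rough illustration of what `summarize_f64_parameter` reports, here is a minimal sketch (hypothetical `summarize` helper, nearest-rank quantiles, no autocorrelation correction) that computes the mean, standard deviation, and a central 95% interval from raw draws:

fn summarize(draws: &mut Vec<f64>) -> (f64, f64, f64, f64) {
    let n = draws.len() as f64;
    let mean = draws.iter().sum::<f64>() / n;
    let std = (draws.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0)).sqrt();

    // Empirical quantiles via sorting (nearest-rank; fine for a sketch)
    draws.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let last = draws.len() - 1;
    let q_lo = draws[(0.025 * last as f64).round() as usize];
    let q_hi = draws[(0.975 * last as f64).round() as usize];
    (mean, std, q_lo, q_hi)
}

fn main() {
    let mut draws = vec![1.9, 2.1, 2.0, 2.3, 1.8, 2.2, 2.05, 1.95];
    let (mean, std, q_lo, q_hi) = summarize(&mut draws);
    println!("mean={:.3} std={:.3} 95% interval=({:.3}, {:.3})", mean, std, q_lo, q_hi);
}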

Advanced Debugging

When models behave unexpectedly, trace analysis reveals the root causes:

use fugue::*;
use fugue::runtime::interpreters::*;
use fugue::runtime::trace::*;
use rand::{SeedableRng, rngs::StdRng};
// Demonstrate advanced debugging techniques
fn advanced_debugging_demo() {
    println!("=== Advanced Debugging Techniques ===\n");

    // Model with potential numerical issues
    let _problematic_model = prob!(
        let scale <- sample(addr!("scale"), Exponential::new(1.0).unwrap());

        // This could cause numerical issues if scale is very small
        let precision <- sample(addr!("precision"), Gamma::new(1.0, scale).unwrap());

        let mu <- sample(addr!("mu"), Normal::new(0.0, 1.0 / precision.sqrt()).unwrap());

        // Observation that might conflict
        let _obs <- observe(addr!("y"), Normal::new(mu, 0.01).unwrap(), 10.0);

        pure((scale, precision, mu))
    );

    println!("๐Ÿšจ Debugging Problematic Model:");

    // Try multiple executions to find issues
    for attempt in 1..=5 {
        let mut rng = StdRng::seed_from_u64(500 + attempt);
        let problematic_model_copy = prob!(
            let scale <- sample(addr!("scale"), Exponential::new(1.0).unwrap());

            // This could cause numerical issues if scale is very small
            let precision <- sample(addr!("precision"), Gamma::new(1.0, scale).unwrap());

            let mu <- sample(addr!("mu"), Normal::new(0.0, 1.0 / precision.sqrt()).unwrap());

            // Observation that might conflict
            let _obs <- observe(addr!("y"), Normal::new(mu, 0.01).unwrap(), 10.0);

            pure((scale, precision, mu))
        );

        let (result, trace) = runtime::handler::run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            problematic_model_copy,
        );

        let (scale, precision, mu) = result;
        let total_logp = trace.total_log_weight();

        println!(
            "   Attempt {}: scale={:.6}, precision={:.6}, mu={:.3}, logp={:.3}",
            attempt, scale, precision, mu, total_logp
        );

        // Check for numerical issues
        if !total_logp.is_finite() {
            println!("     โš ๏ธ Non-finite log-probability detected!");
        }

        if precision < 1e-6 {
            println!("     โš ๏ธ Very small precision: {:.8}", precision);
        }

        if mu.abs() > 5.0 {
            println!("     โš ๏ธ Extreme mu value: {:.3}", mu);
        }

        // Examine individual components
        println!(
            "     Components: prior={:.3}, likelihood={:.3}, factors={:.3}",
            trace.log_prior, trace.log_likelihood, trace.log_factors
        );
    }

    println!();
    println!("๐Ÿ” Trace Validation:");

    // Create a trace with known good values for validation
    let validation_trace = Trace {
        choices: [
            (
                addr!("scale"),
                Choice {
                    addr: addr!("scale"),
                    value: ChoiceValue::F64(1.0),
                    logp: Exponential::new(1.0).unwrap().log_prob(&1.0),
                },
            ),
            (
                addr!("precision"),
                Choice {
                    addr: addr!("precision"),
                    value: ChoiceValue::F64(2.0),
                    logp: Gamma::new(1.0, 1.0).unwrap().log_prob(&2.0),
                },
            ),
            (
                addr!("mu"),
                Choice {
                    addr: addr!("mu"),
                    value: ChoiceValue::F64(0.5),
                    logp: Normal::new(0.0, 1.0 / (2.0_f64).sqrt())
                        .unwrap()
                        .log_prob(&0.5),
                },
            ),
        ]
        .iter()
        .cloned()
        .collect(),
        log_prior: 0.0,
        log_likelihood: 0.0,
        log_factors: 0.0,
    };

    // Score this validation trace
    let validation_model = prob!(
        let scale <- sample(addr!("scale"), Exponential::new(1.0).unwrap());
        let precision <- sample(addr!("precision"), Gamma::new(1.0, scale).unwrap());
        let mu <- sample(addr!("mu"), Normal::new(0.0, 1.0 / precision.sqrt()).unwrap());
        let _obs <- observe(addr!("y"), Normal::new(mu, 0.01).unwrap(), 10.0);
        pure((scale, precision, mu))
    );

    let (val_result, val_trace) = runtime::handler::run(
        ScoreGivenTrace {
            base: validation_trace,
            trace: Trace::default(),
        },
        validation_model,
    );

    println!("   - Validation result: {:?}", val_result);
    println!("   - Validation logp: {:.3}", val_trace.total_log_weight());
    println!(
        "   - Validation finite: {}",
        val_trace.total_log_weight().is_finite()
    );
    println!();
}

Debugging Strategy

graph TD
    A["Model Issues"] --> B["Check Log-Weights"]
    B --> C["Infinite/NaN Values?"]
    B --> D["Extreme Parameter Values?"]
    B --> E["Prior-Likelihood Conflict?"]
    
    C --> F["Examine Individual Choices"]
    D --> G["Check Parameter Ranges"]
    E --> H["Validate Observations"]
    
    F --> I["Fix Distribution Parameters"]
    G --> J["Add Constraints/Priors"] 
    H --> K["Verify Data Consistency"]
    
    I --> L["Test with Validation Trace"]
    J --> L
    K --> L
    
    style C fill:#ffcccc
    style D fill:#ffcccc
    style E fill:#ffcccc

Common Issues and Solutions

| Problem             | Symptom                  | Solution                   |
|---------------------|--------------------------|----------------------------|
| Numerical overflow  | Inf log-weights          | Use log-space throughout   |
| Parameter explosion | Extreme values           | Add regularizing priors    |
| Prior-data conflict | Very negative likelihood | Check data preprocessing   |
| Precision issues    | Unstable gradients       | Use higher precision types |
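
For the first row of the table, the standard remedy is to combine weights in log-space. A minimal log-sum-exp sketch in plain Rust (not a Fugue API) that avoids exponentiating very negative log-weights:

// Combine log-weights without ever exponentiating extreme values directly.
fn log_sum_exp(log_weights: &[f64]) -> f64 {
    let max = log_weights.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if !max.is_finite() {
        return max; // all weights are zero (or a weight is already infinite)
    }
    max + log_weights.iter().map(|lw| (lw - max).exp()).sum::<f64>().ln()
}

fn main() {
    // Exponentiating these naively would underflow f64 to 0.0
    let log_weights = [-1050.0, -1052.3, -1049.1];
    println!("log of summed weights = {:.3}", log_sum_exp(&log_weights));
}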

Production Debugging

Always validate your models with:

  1. Known-good traces with reasonable parameter values
  2. Synthetic data where you know the true parameters
  3. Multiple random seeds to check consistency
  4. Finite-value assertions in your handlers
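
Item 4 can be as simple as asserting on `total_log_weight()` immediately after each run. A minimal sketch using the PriorHandler pattern from earlier sections; the one-parameter model and seeds here are illustrative, not from the examples above:

use fugue::*;
use fugue::runtime::interpreters::PriorHandler;
use fugue::runtime::trace::Trace;
use rand::{SeedableRng, rngs::StdRng};

fn main() {
    for seed in 0..5u64 {
        let mut rng = StdRng::seed_from_u64(seed);
        // Illustrative one-parameter model
        let model = prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap());
            let _obs <- observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.5);
            pure(mu)
        );
        let (_, trace) = runtime::handler::run(
            PriorHandler { rng: &mut rng, trace: Trace::default() },
            model,
        );
        // Fail fast on numerical problems instead of silently propagating them
        assert!(
            trace.total_log_weight().is_finite(),
            "non-finite log-weight at seed {seed}"
        );
    }
    println!("all seeds produced finite log-weights");
}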

Real-World Applications

Custom MCMC Algorithm

use fugue::*;
use fugue::runtime::{interpreters::*, trace::*};

struct CustomMCMC<R: rand::Rng> {
    rng: R,
    current_trace: Trace,
    step_size: f64,
}

impl<R: rand::Rng> CustomMCMC<R> {
    fn step<F>(&mut self, model_fn: F) -> bool 
    where F: Fn() -> Model<f64>
    {
        // Create proposal by modifying current trace
        let mut proposal_trace = self.current_trace.clone();
        
        // Modify a random choice (simplified)
        if let Some((addr, choice)) = proposal_trace.choices.iter_mut().next() {
            if let Some(current_val) = choice.value.as_f64() {
                let proposal_val = current_val + self.step_size * 
                    Normal::new(0.0, 1.0).unwrap().sample(&mut self.rng);
                choice.value = ChoiceValue::F64(proposal_val);
            }
        }
        
        // Score proposal
        let (_, scored_trace) = runtime::handler::run(
            ScoreGivenTrace { base: proposal_trace, trace: Trace::default() },
            model_fn()
        );
        
        // Accept/reject based on Metropolis criterion
        let log_alpha = scored_trace.total_log_weight() - 
                       self.current_trace.total_log_weight();
        
        if log_alpha > 0.0 || 
           self.rng.gen::<f64>().ln() < log_alpha {
            self.current_trace = scored_trace;
            true // Accepted
        } else {
            false // Rejected  
        }
    }
}

Production Inference Pipeline

use fugue::*;
use fugue::runtime::memory::TracePool;

struct InferencePipeline {
    pool: TracePool,
    diagnostics: Vec<f64>,
}

impl InferencePipeline {
    fn run_batch<F>(&mut self, 
                   model_fn: F, 
                   n_samples: usize) -> Vec<(f64, Trace)>
    where F: Fn() -> Model<f64> + Copy
    {
        let mut results = Vec::with_capacity(n_samples);
        
        for _ in 0..n_samples {
            // Get pooled trace to avoid allocation
            let pooled_trace = self.pool.get_trace();
            
            let mut rng = rand::thread_rng();
            let handler = PriorHandler { 
                rng: &mut rng, 
                trace: pooled_trace 
            };
            
            let (result, trace) = runtime::handler::run(handler, model_fn());
            
            // Record diagnostics
            self.diagnostics.push(trace.total_log_weight());
            
            results.push((result, trace));
        }
        
        results
    }
    
    fn convergence_summary(&self) -> (f64, f64) {
        let mean = self.diagnostics.iter().sum::<f64>() / self.diagnostics.len() as f64;
        let var = self.diagnostics.iter()
            .map(|x| (x - mean).powi(2))
            .sum::<f64>() / (self.diagnostics.len() - 1) as f64;
        (mean, var.sqrt())
    }
}

Best Practices

Handler Development

Custom Handler Guidelines

  1. Type Safety: Always match handler methods to distribution return types
  2. Error Handling: Use Result types for production handlers
  3. State Management: Keep handler state minimal and well-documented
  4. Performance: Pre-allocate collections when possible
  5. Testing: Validate against known-good traces

Memory Management

Production Optimization

  1. Profile First: Measure actual memory usage patterns
  2. Pool Strategically: Use TracePool for repeated operations
  3. Size Appropriately: Pre-size traces when choice count is predictable
  4. Monitor Growth: Watch for memory leaks in long-running processes

Debugging Workflow

Systematic Debugging

  1. Check Basics: Verify all log-weights are finite
  2. Isolate Components: Test prior, likelihood, factors separately
  3. Use Validation: Create traces with known-good parameter values
  4. Compare Algorithms: Try different inference methods
  5. Visualize Traces: Plot parameter trajectories over time

Testing Your Understanding

Exercise 1: Custom Proposal Mechanism

Implement a custom handler that uses adaptive proposals based on the acceptance rate history:

use fugue::*;
use fugue::runtime::{handler::Handler, trace::*};

struct AdaptiveMCMCHandler<R: rand::Rng> {
    rng: R,
    current_trace: Trace,
    proposal_scale: f64,
    acceptance_history: Vec<bool>,
    adaptation_interval: usize,
}

// TODO: Implement Handler trait with adaptive step size

Exercise 2: Multi-Chain Diagnostics

Create a system that runs multiple MCMC chains in parallel and automatically assesses convergence:

use fugue::inference::diagnostics::*;

fn multi_chain_inference<F>(
    model_fn: F,
    n_chains: usize, 
    n_samples: usize
) -> (Vec<Vec<Trace>>, bool)
where F: Fn() -> Model<f64> + Copy
{
    // TODO: Run multiple chains and check R-hat convergence
    unimplemented!()
}

Exercise 3: Memory-Optimized Batch Processing

Design a system for processing thousands of similar models efficiently:

use fugue::runtime::memory::*;

struct BatchProcessor {
    pool: TracePool,
    // TODO: Add fields for efficient batch processing
}

impl BatchProcessor {
    fn process_batch<F>(&mut self, 
                       models: Vec<F>) -> Vec<(f64, Trace)>
    where F: Fn() -> Model<f64>
    {
        // TODO: Implement memory-efficient batch processing
        unimplemented!()
    }
}

Key Takeaways

Trace Manipulation Mastery

  1. Execution History: Traces record complete probabilistic execution paths
  2. Handler Flexibility: The same model can be executed in radically different ways
  3. Replay Foundation: MCMC and other algorithms depend on deterministic replay
  4. Custom Strategies: Implement specialized inference through custom handlers
  5. Production Ready: Memory optimization and diagnostics enable robust deployment
  6. Debugging Power: Trace analysis reveals numerical issues and convergence problems

Core Capabilities:

  • โœ… Complete execution recording with type safety and weight decomposition
  • โœ… Flexible interpretation through the handler system
  • โœ… MCMC foundation via deterministic replay mechanics
  • โœ… Custom inference algorithms through handler extensibility
  • โœ… Production optimization with memory pooling and efficient allocation
  • โœ… Comprehensive diagnostics for convergence assessment and debugging

Further Reading

  • Custom Handlers Guide - Building specialized interpreters
  • Optimizing Performance - Production deployment strategies
  • Debugging Models - Troubleshooting problematic models
  • API Reference - Complete runtime system specification
  • The Elements of Statistical Learning - Theoretical foundations of inference algorithms
  • Monte Carlo Statistical Methods - MCMC theory and practice

Statistical Modeling

Contents

This section provides comprehensive coverage of statistical modeling using Fugue:

Welcome to statistical modeling with Fugue! This section demonstrates how Bayesian probabilistic programming transforms traditional statistical analysis through principled uncertainty quantification, robust inference, and flexible model specification.

Why Bayesian Statistical Modeling?

Traditional statistics gives you point estimates and p-values.
Bayesian modeling gives you full posterior distributions, prediction intervals, and principled model comparison.

โœ… Natural uncertainty quantification for all parameters
โœ… Robust inference with constraint-aware MCMC
โœ… Principled model selection via Bayes factors and information criteria
โœ… Flexible prior knowledge integration through hierarchical structures
โœ… Automatic regularization prevents overfitting
โœ… Production-ready workflows with comprehensive diagnostics

Learning Path

graph TB
    A[Statistical Modeling Journey] --> B[Foundation]
    A --> C[Supervised Learning] 
    A --> D[Unsupervised Learning]
    A --> E[Advanced Methods]

    B --> F["Linear Regression<br/>๐Ÿ“Š Basic Bayesian inference<br/>๐Ÿ“Š Robust methods<br/>๐Ÿ“Š Polynomial models<br/>๐Ÿ“Š Model selection"]

    C --> G["Classification<br/>๐Ÿง  Logistic regression<br/>๐Ÿง  Multi-class methods<br/>๐Ÿง  Hierarchical classification<br/>๐Ÿง  Model comparison"]

    D --> H["Mixture Models<br/>๐Ÿงฌ Gaussian mixtures<br/>๐Ÿงฌ Infinite mixtures<br/>๐Ÿงฌ Hidden Markov models<br/>๐Ÿงฌ Clustering validation"]

    E --> I["Hierarchical Models<br/>๐Ÿข Varying intercepts<br/>๐Ÿข Mixed effects<br/>๐Ÿข Nested structures<br/>๐Ÿข Partial pooling"]

    F --> J["Applications<br/>๐Ÿ”ฌ Economic forecasting<br/>๐Ÿ”ฌ Medical research<br/>๐Ÿ”ฌ Scientific modeling"]
    G --> J
    H --> J
    I --> J

Structured Learning Path

Beginners: Start with Linear Regression โ†’ Classification

Intermediate: Add Mixture Models for unsupervised learning

Advanced: Master Hierarchical Models for complex data structures

All Levels: Each tutorial includes complete working examples with runnable code!

Tutorial Overview

๐Ÿ“Š Linear Regression

The foundation of statistical modeling, covering:

  • Basic Bayesian regression with uncertainty quantification
  • Robust regression for outlier resistance
  • Polynomial regression for nonlinear relationships
  • Model selection using Bayes factors
  • Ridge regression for high-dimensional problems
use fugue::*;
// Example: Basic Bayesian linear regression
let model = prob! {
    let intercept <- sample(addr!("intercept"), Normal::new(0.0, 10.0).unwrap());
    let slope <- sample(addr!("slope"), Normal::new(0.0, 10.0).unwrap());
    let sigma <- sample(addr!("sigma"), Gamma::new(1.0, 1.0).unwrap());
    
    // Observations with uncertainty
    for (i, (x_i, y_i)) in x_data.iter().zip(y_data.iter()).enumerate() {
        let mu_i = intercept + slope * x_i;
        observe(addr!("y", i), Normal::new(mu_i, sigma).unwrap(), *y_i);
    }
    
    pure((intercept, slope, sigma))
};

๐Ÿง  Classification

Discrete outcome modeling with comprehensive coverage:

  • Binary classification with logistic regression
  • Multi-class classification using multinomial methods
  • Hierarchical classification for grouped data
  • Model comparison and performance evaluation
use fugue::*;
// Example: Hierarchical logistic regression
let model = prob! {
    let global_intercept <- sample(addr!("global_intercept"), Normal::new(0.0, 2.0).unwrap());
    let slope <- sample(addr!("slope"), Normal::new(0.0, 2.0).unwrap());
    let group_sigma <- sample(addr!("group_sigma"), Gamma::new(1.0, 1.0).unwrap());
    
    // Group-specific intercepts with partial pooling
    let group_intercepts <- plate!(g in 0..n_groups => {
        sample(addr!("group_intercept", g), Normal::new(global_intercept, group_sigma).unwrap())
    });
    
    pure((global_intercept, slope, group_intercepts))
};

๐Ÿงฌ Mixture Models

Advanced unsupervised learning techniques:

  • Gaussian mixtures for continuous data clustering
  • Multivariate mixtures with correlation structure
  • Infinite mixtures with automatic component discovery
  • Hidden Markov models for temporal clustering
use fugue::*;
// Example: Gaussian mixture with latent variables
let model = prob! {
    let pi1 <- sample(addr!("pi1"), Beta::new(1.0, 1.0).unwrap());
    let mu1 <- sample(addr!("mu1"), Normal::new(0.0, 5.0).unwrap());
    let sigma1 <- sample(addr!("sigma1"), Gamma::new(1.0, 1.0).unwrap());
    
    // Latent cluster assignments
    let assignments <- plate!(i in 0..data.len() => {
        sample(addr!("z", i), Categorical::new(vec![pi1, 1.0 - pi1]).unwrap())
    });
    
    pure((pi1, mu1, sigma1, assignments))
};

๐Ÿข Hierarchical Models

Multi-level modeling for complex data structures:

  • Varying intercepts for group-level baseline differences
  • Varying slopes for group-level relationship differences
  • Mixed effects combining fixed and random effects
  • Nested hierarchies for multi-level clustering
use fugue::*;
// Example: Varying intercepts model
let model = prob! {
    // Population-level hyperparameters
    let mu_alpha <- sample(addr!("mu_alpha"), Normal::new(0.0, 5.0).unwrap());
    let sigma_alpha <- sample(addr!("sigma_alpha"), Gamma::new(1.0, 1.0).unwrap());
    let beta <- sample(addr!("beta"), Normal::new(0.0, 2.0).unwrap());
    
    // Group-specific intercepts via partial pooling
    let _observations <- plate!(i in 0..x_data.len() => {
        let group_j = group_ids[i];
        sample(addr!("alpha", group_j), Normal::new(mu_alpha, sigma_alpha).unwrap())
            .bind(move |alpha_j| {
                let mu_i = alpha_j + beta * x_data[i];
                observe(addr!("y", i), Normal::new(mu_i, sigma_y).unwrap(), y_data[i])
            })
    });
    
    pure((mu_alpha, sigma_alpha, beta, sigma_y))
};

Key Statistical Concepts

Bayesian Inference Pipeline

graph LR
    A["Data<br/>yโ‚, yโ‚‚, ..., yโ‚™"] --> B["Model<br/>p(y|ฮธ)"]
    C["Priors<br/>p(ฮธ)"] --> B

    B --> D["Posterior<br/>p(ฮธ|y) โˆ p(y|ฮธ)p(ฮธ)"]

    D --> E["MCMC Sampling<br/>ฮธโฝยนโพ, ฮธโฝยฒโพ, ..., ฮธโฝแดนโพ"]

    E --> F["Inference<br/>๐Ÿ“Š Point estimates<br/>๐Ÿ“Š Credible intervals<br/>๐Ÿ“Š Predictions"]

    E --> G["Diagnostics<br/>๐Ÿ” Convergence<br/>๐Ÿ” Model checking<br/>๐Ÿ” Validation"]

Core Advantages

| Traditional Methods         | Bayesian Methods               |
|-----------------------------|--------------------------------|
| Point estimates             | Full posterior distributions   |
| Confidence intervals        | Credible intervals             |
| P-values                    | Bayes factors                  |
| Ad-hoc regularization       | Principled prior specification |
| Model selection via AIC/BIC | Marginal likelihood comparison |

Practical Implementation

Essential Patterns

Model Structure:

use fugue::*;
let model = prob! {
    // 1. Prior specification
    let parameter <- sample(addr!("param"), Normal::new(0.0, 1.0).unwrap());
    
    // 2. Likelihood specification
    for (i, observation) in data.iter().enumerate() {
        observe(addr!("obs", i), distribution, *observation);
    }
    
    // 3. Return parameters of interest
    pure(parameter)
};

MCMC Workflow:

use fugue::inference::mh::adaptive_mcmc_chain;
use rand::{SeedableRng, rngs::StdRng};

// 1. Define model function
let model_fn = move || your_statistical_model(data.clone());

// 2. Run adaptive MCMC
let mut rng = StdRng::seed_from_u64(42);
let samples = adaptive_mcmc_chain(&mut rng, model_fn, 1000, 200);

// 3. Extract and analyze results
let parameter_samples: Vec<f64> = samples.iter()
    .map(|(params, _)| params.parameter_of_interest)
    .collect();

Model Selection Framework

Information Criteria

| Criterion        | Formula                           | Use Case                   |
|------------------|-----------------------------------|----------------------------|
| DIC              | D̄ + p_D                           | General model comparison   |
| WAIC             | −2(lppd − p_WAIC)                 | Robust alternative to DIC  |
| Bayes Factor     | p(y \| M₁) / p(y \| M₂)           | Direct model evidence      |
| Cross-Validation | Leave-one-out predictive accuracy | Out-of-sample validation   |
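
As a back-of-the-envelope illustration of the DIC row, the sketch below (hypothetical inputs: per-draw log-likelihoods plus the log-likelihood evaluated at the posterior mean) shows how the criterion is assembled:

// DIC = D-bar + p_D, where D-bar is the posterior mean deviance and p_D the
// effective number of parameters. Inputs here are hypothetical.
fn dic(log_lik_draws: &[f64], log_lik_at_posterior_mean: f64) -> f64 {
    let mean_deviance = -2.0 * log_lik_draws.iter().sum::<f64>() / log_lik_draws.len() as f64;
    let deviance_at_mean = -2.0 * log_lik_at_posterior_mean;
    let p_d = mean_deviance - deviance_at_mean; // effective number of parameters
    mean_deviance + p_d
}

fn main() {
    let log_lik_draws = [-42.1, -41.8, -43.0, -42.4];
    println!("DIC ≈ {:.2}", dic(&log_lik_draws, -41.5));
}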

Model Building Strategy

Systematic Model Development

  1. Start Simple: Begin with basic models (e.g., linear regression)
  2. Add Complexity Gradually: Introduce robustness, nonlinearity, hierarchy as needed
  3. Compare Systematically: Use information criteria and cross-validation
  4. Validate Thoroughly: Check residuals, convergence, and predictive performance
  5. Document Assumptions: Clearly state model assumptions and limitations

Running the Examples

All tutorials include complete, runnable examples:

# Linear regression demonstrations
cargo run --example linear_regression

# Classification methods
cargo run --example classification  

# Mixture modeling techniques
cargo run --example mixture_models

# Hierarchical model applications
cargo run --example hierarchical_models

# Run all statistical modeling tests
cargo test --example linear_regression
cargo test --example classification
cargo test --example mixture_models  
cargo test --example hierarchical_models

Production Deployment

Scalability Considerations

use fugue::*;

// For large datasets, consider:

// 1. Mini-batch processing (sketch: `Data`, `ModelFn`, `rng`, `n_samples`, `warmup` are placeholders)
fn minibatch_mcmc(data_chunks: Vec<Vec<Data>>, model_fn: ModelFn) {
    for chunk in data_chunks {
        // Clone the chunk into the closure so the model can be rebuilt on every MCMC step
        let samples = adaptive_mcmc_chain(&mut rng, || model_fn(chunk.clone()), n_samples, warmup);
        // Process samples...
    }
}

// 2. Parallel inference
use std::thread;
let handles: Vec<_> = (0..n_chains).map(|chain_id| {
    thread::spawn(move || {
        let mut rng = StdRng::seed_from_u64(chain_id as u64);
        adaptive_mcmc_chain(&mut rng, model_fn, n_samples, warmup)
    })
}).collect();

// 3. Streaming inference for real-time data

Monitoring and Diagnostics

Production Checklist

Essential monitoring for production Bayesian models:

โœ… Convergence diagnostics: R-hat < 1.1, effective sample size > 100
โœ… Prior sensitivity: Results stable across reasonable prior choices
โœ… Posterior predictive checks: Model captures key data features
โœ… Cross-validation: Stable out-of-sample performance
โœ… Computational efficiency: Reasonable wall-clock time for inference
โœ… Parameter stability: Results consistent across multiple runs
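
For the prior-sensitivity item, one lightweight pattern is to rerun the same model under a few prior scales and compare posterior summaries. A sketch with an illustrative toy model (not one of the tutorial models), using `adaptive_mcmc_chain` as elsewhere in this guide:

use fugue::*;
use fugue::inference::mh::adaptive_mcmc_chain;
use rand::{SeedableRng, rngs::StdRng};

fn main() {
    for &prior_sd in &[1.0, 5.0, 10.0] {
        // Same illustrative model, different prior scale on mu
        let model_fn = move || prob!(
            let mu <- sample(addr!("mu"), Normal::new(0.0, prior_sd).unwrap());
            let _obs <- observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 2.0);
            pure(mu)
        );

        let mut rng = StdRng::seed_from_u64(11);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 500, 100);

        let mus: Vec<f64> = samples
            .iter()
            .filter_map(|(_, trace)| trace.get_f64(&addr!("mu")))
            .collect();
        let mean = mus.iter().sum::<f64>() / mus.len() as f64;
        println!("prior sd {prior_sd}: posterior mean of mu = {mean:.3}");
    }
}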

Mathematical Foundations

Core Statistical Models

Linear Models: yᵢ = β₀ + β₁xᵢ + εᵢ, with εᵢ ~ Normal(0, σ²)

Generalized Linear Models: g(E[yᵢ]) = xᵢᵀβ for a link function g (e.g., the logit link for classification)

Hierarchical Models: yᵢⱼ ~ Normal(αⱼ + βxᵢⱼ, σ²), with group effects αⱼ ~ Normal(μ_α, σ_α²)

Mixture Models: p(yᵢ) = Σₖ πₖ Normal(yᵢ | μₖ, σₖ²), with mixing weights summing to one

Bayesian Workflow

graph TB
    A[Domain Problem] --> B[Statistical Question]
    B --> C[Model Specification]

    C --> D[Prior Elicitation]
    C --> E[Likelihood Choice]  
    C --> F[Parameter Structure]

    D --> G[Posterior Inference]
    E --> G
    F --> G

    G --> H[MCMC Sampling]
    H --> I[Convergence Diagnostics]

    I --> J{Converged?}
    J -->|No| K[Adjust Model/Priors] --> C
    J -->|Yes| L[Model Checking]

    L --> M{Model Adequate?}
    M -->|No| N[Refine Model] --> C
    M -->|Yes| O[Scientific Inference]

    O --> P[Decision/Action]

Advanced Topics

Model Extensions

Each tutorial demonstrates advanced extensions:

  • Robustness: Heavy-tailed distributions, outlier modeling
  • Nonlinearity: Polynomial basis, spline methods, kernels
  • Correlation: Multivariate models, spatial/temporal correlation
  • Hierarchical Structure: Multi-level, nested, cross-classified models
  • Model Uncertainty: Averaging, selection, expansion

Computational Methods

use fugue::*;

// Advanced MCMC techniques demonstrated:

// 1. Constraint-aware proposals for positive parameters
// Automatically handled by Fugue's MCMC implementation

// 2. Adaptive MCMC for efficient exploration
let samples = adaptive_mcmc_chain(&mut rng, model_fn, n_samples, warmup);

// 3. Multiple chains for convergence assessment  
let chains: Vec<_> = (0..n_chains).map(|seed| {
    let mut rng = StdRng::seed_from_u64(seed as u64);
    adaptive_mcmc_chain(&mut rng, model_fn.clone(), n_samples, warmup)
}).collect();

// 4. Posterior predictive sampling
let predictions: Vec<f64> = samples.iter().map(|(params, _)| {
    // Generate predictions using posterior samples
    predictive_model(new_x, params)
}).collect();

Integration with Fugue Ecosystem

Type Safety Benefits

use fugue::*;

// Fugue's type safety prevents common statistical errors:

let bernoulli = Bernoulli::new(0.7).unwrap();
let outcome: bool = bernoulli.sample(&mut rng);  // Returns bool, not int!

let categorical = Categorical::new(vec![0.2, 0.3, 0.5]).unwrap(); 
let class: usize = categorical.sample(&mut rng);  // Safe indexing!

let normal = Normal::new(0.0, 1.0).unwrap();
let value: f64 = normal.sample(&mut rng);  // Explicit numeric type!

Runtime Integration

use fugue::runtime::handler::run;
use fugue::runtime::interpreters::PriorHandler;

// Seamless integration with Fugue's runtime system:

// 1. Prior sampling for model validation
let (result, trace) = run(
    PriorHandler { rng: &mut rng, trace: Trace::default() },
    your_model()
);

// 2. Scoring for model comparison
let (_, scored_trace) = run(
    ScoreGivenTrace { base: trace, trace: Trace::default() },
    your_model(),
);

// 3. Replay for debugging
let replay_trace = ReplayHandler::new(previous_trace)
    .replay(&mut rng, your_model());

Common Statistical Tasks

1. Parameter Estimation

  • Point estimates via posterior means
  • Uncertainty via credible intervals
  • Hypothesis testing via posterior probabilities

2. Prediction

  • Point predictions with uncertainty bands
  • Posterior predictive distributions
  • Out-of-sample validation

3. Model Selection

  • Information criteria comparison
  • Bayes factor evidence assessment
  • Cross-validation performance

4. Model Checking

  • Posterior predictive checks
  • Residual analysis
  • Convergence diagnostics
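
A posterior predictive check can be as small as comparing one summary statistic of the observed data against replicated datasets drawn from posterior samples. A minimal sketch with illustrative posterior draws and data (the (mu, sigma) pairs below are made up for demonstration):

use fugue::*;
use rand::{SeedableRng, rngs::StdRng};

fn main() {
    // Observed data and its summary statistic (the sample maximum)
    let observed = vec![1.2, 0.8, 1.5, 0.9, 1.1];
    let obs_max = observed.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

    // Illustrative posterior draws of (mu, sigma) for a Normal data model
    let posterior_draws = vec![(1.0, 0.3), (1.1, 0.25), (0.95, 0.35), (1.05, 0.3)];
    let mut rng = StdRng::seed_from_u64(3);

    // For each draw, simulate a replicated dataset and record whether its
    // maximum exceeds the observed maximum.
    let exceed = posterior_draws
        .iter()
        .filter(|&&(mu, sigma)| {
            let dist = Normal::new(mu, sigma).unwrap();
            let rep_max = (0..observed.len())
                .map(|_| dist.sample(&mut rng))
                .fold(f64::NEG_INFINITY, f64::max);
            rep_max > obs_max
        })
        .count();

    let p = exceed as f64 / posterior_draws.len() as f64;
    println!("posterior predictive p-value (max statistic): {p:.2}");
}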

Best Practices Summary

๐ŸŽฏ Model Building: Start simple, add complexity gradually, validate thoroughly

๐Ÿ”ฌ Prior Selection: Use weakly informative priors, check sensitivity

๐Ÿ“Š Inference: Monitor convergence, assess adequacy, quantify uncertainty

๐Ÿš€ Production: Automate diagnostics, cache samples, monitor performance

๐Ÿ“š Communication: Visualize uncertainty, explain methodology, document assumptions

Further Reading

Fugue Documentation

Statistical References

  • Gelman et al. "Bayesian Data Analysis" - Comprehensive Bayesian statistics
  • McElreath "Statistical Rethinking" - Modern computational approach
  • Kruschke "Doing Bayesian Data Analysis" - Applied Bayesian methods
  • Murphy "Machine Learning: A Probabilistic Perspective" - ML and statistics integration

Advanced Topics


Statistical modeling with Fugue combines the theoretical rigor of Bayesian inference with the practical advantages of type-safe probabilistic programming. Whether you're analyzing experimental data, building predictive models, or exploring complex relationships, these tutorials provide the foundation for principled, robust, and scalable statistical analysis.

๐ŸŽ“ Start your statistical modeling journey with the tutorial that matches your current needs and experience level!

Linear Regression

A comprehensive guide to Bayesian linear regression using Fugue. This tutorial demonstrates how to build, analyze, and extend linear models for real-world data analysis, showcasing the power of probabilistic programming for uncertainty quantification and model comparison.

Learning Objectives

By the end of this tutorial, you will understand:

  • Bayesian Linear Regression: Prior specification and posterior inference for regression parameters
  • Uncertainty Quantification: How to extract and interpret parameter uncertainty from MCMC samples
  • Robust Regression: Using heavy-tailed distributions to handle outliers
  • Polynomial Regression: Modeling nonlinear relationships with polynomial basis functions
  • Model Selection: Bayesian methods for comparing regression models
  • Regularization: Ridge regression through hierarchical priors
  • Production Applications: Scalable inference for high-dimensional regression problems

The Linear Regression Framework

Linear regression is the cornerstone of statistical modeling. In the Bayesian framework, we treat regression parameters as random variables with prior distributions, allowing us to quantify uncertainty in our estimates and make probabilistic predictions.

graph TB
    A["Data: (xโ‚,yโ‚), (xโ‚‚,yโ‚‚), ..., (xโ‚™,yโ‚™)"] --> B["Linear Model<br/>y = ฮฒโ‚€ + ฮฒโ‚x + ฮต"]

    B --> C["Bayesian Framework"]
    C --> D["Prior: p(ฮฒโ‚€, ฮฒโ‚, ฯƒ)"]
    C --> E["Likelihood: p(y|X, ฮฒโ‚€, ฮฒโ‚, ฯƒ)"]

    D --> F["Posterior: p(ฮฒโ‚€, ฮฒโ‚, ฯƒ|y, X)"]
    E --> F

    F --> G["MCMC Sampling"]
    G --> H["Parameter Uncertainty"]
    G --> I["Predictive Distribution"]
    G --> J["Model Comparison"]

Mathematical Foundation

Basic Linear Model

The fundamental linear regression model is:

yᵢ = β₀ + β₁xᵢ + εᵢ,  εᵢ ~ Normal(0, σ²)

where:

  • yᵢ: Response variable (dependent)
  • xᵢ: Predictor variable (independent)
  • β₀: Intercept parameter
  • β₁: Slope parameter
  • εᵢ: Random error

Bayesian Specification

Prior Distributions:

  • β₀ ~ Normal(0, 10²), β₁ ~ Normal(0, 10²)
  • σ ~ Gamma(1, 1) or another weakly informative prior over positive values

Likelihood:

yᵢ | xᵢ, β₀, β₁, σ ~ Normal(β₀ + β₁xᵢ, σ²)

Posterior:

p(β₀, β₁, σ | y, X) ∝ p(y | X, β₀, β₁, σ) · p(β₀, β₁, σ)

Conjugate Analysis

When using conjugate priors (Normal-Inverse-Gamma), the posterior has a closed form. However, MCMC allows us to use more flexible priors and handle complex models without conjugacy restrictions.
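
For intuition, here is the closed-form update for the simplest conjugate case (a Normal mean with known noise standard deviation); the helper name and numbers are illustrative, not part of Fugue's API:

// Prior mu ~ Normal(m0, s0^2), data y_i ~ Normal(mu, sigma^2) with sigma known.
fn conjugate_normal_posterior(m0: f64, s0: f64, sigma: f64, data: &[f64]) -> (f64, f64) {
    let n = data.len() as f64;
    let ybar = data.iter().sum::<f64>() / n;
    let prec = 1.0 / (s0 * s0) + n / (sigma * sigma);          // posterior precision
    let mean = (m0 / (s0 * s0) + n * ybar / (sigma * sigma)) / prec;
    (mean, (1.0 / prec).sqrt())                                 // posterior mean and sd
}

fn main() {
    let (mean, sd) = conjugate_normal_posterior(0.0, 10.0, 0.5, &[2.4, 2.6, 2.5]);
    println!("posterior for mu: Normal({mean:.3}, {sd:.3}²)");
}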

Basic Linear Regression

Let's start with the fundamental case: simple linear regression with one predictor variable.

Implementation

use fugue::*;
use fugue::runtime::interpreters::PriorHandler;
use rand::{SeedableRng, rngs::StdRng};
// Basic Bayesian linear regression model
fn basic_linear_regression_model(x_data: Vec<f64>, y_data: Vec<f64>) -> Model<(f64, f64, f64)> {
    prob! {
        let intercept <- sample(addr!("intercept"), Normal::new(0.0, 10.0).unwrap());
        let slope <- sample(addr!("slope"), Normal::new(0.0, 10.0).unwrap());

        // Use a well-behaved prior for sigma (now that MCMC handles positivity constraints)
        let sigma <- sample(addr!("sigma"), Gamma::new(1.0, 1.0).unwrap()); // Mean = 1, more concentrated

        // Simple observations (limited number for efficiency)
        let _obs_0 <- observe(addr!("y", 0), Normal::new(intercept + slope * x_data[0], sigma).unwrap(), y_data[0]);
        let _obs_1 <- observe(addr!("y", 1), Normal::new(intercept + slope * x_data[1], sigma).unwrap(), y_data[1]);
        let _obs_2 <- observe(addr!("y", 2), Normal::new(intercept + slope * x_data[2], sigma).unwrap(), y_data[2]);

        pure((intercept, slope, sigma))
    }
}

fn basic_regression_demo() {
    println!("=== Basic Linear Regression ===\n");

    // Generate synthetic data: y = 2 + 1.5*x + noise (smaller dataset for demo)
    let (x_data, y_data) = generate_regression_data(20, 1.5, 2.0, 0.5, 12345);

    println!("๐Ÿ“Š Generated {} data points", x_data.len());
    println!("   - True intercept: 2.0, True slope: 1.5, True sigma: 0.5");
    println!(
        "   - Data range: x โˆˆ [{:.1}, {:.1}], y โˆˆ [{:.1}, {:.1}]",
        x_data[0],
        x_data[x_data.len() - 1],
        y_data.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
        y_data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
    );

    // Create model function that uses the data
    let model_fn = move || basic_linear_regression_model(x_data.clone(), y_data.clone());

    println!("\n๐Ÿ”ฌ Running MCMC inference...");
    let mut rng = StdRng::seed_from_u64(42);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 500, 100);

    // Extract parameter estimates
    let intercepts: Vec<f64> = samples
        .iter()
        .filter_map(|(_, trace)| trace.get_f64(&addr!("intercept")))
        .collect();
    let slopes: Vec<f64> = samples
        .iter()
        .filter_map(|(_, trace)| trace.get_f64(&addr!("slope")))
        .collect();
    let sigmas: Vec<f64> = samples
        .iter()
        .filter_map(|(_, trace)| trace.get_f64(&addr!("sigma")))
        .collect();

    if !intercepts.is_empty() && !slopes.is_empty() && !sigmas.is_empty() {
        println!("โœ… MCMC completed with {} samples", samples.len());
        println!("\n๐Ÿ“ˆ Parameter Estimates:");

        let mean_intercept = intercepts.iter().sum::<f64>() / intercepts.len() as f64;
        let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;
        let mean_sigma = sigmas.iter().sum::<f64>() / sigmas.len() as f64;

        println!("   - Intercept: {:.3} (true: 2.0)", mean_intercept);
        println!("   - Slope: {:.3} (true: 1.5)", mean_slope);
        println!("   - Sigma: {:.3} (true: 0.5)", mean_sigma);

        // Show some diagnostics
        let valid_traces = samples
            .iter()
            .filter(|(_, trace)| trace.total_log_weight().is_finite())
            .count();
        println!("   - Valid traces: {} / {}", valid_traces, samples.len());
    } else {
        println!("โŒ MCMC failed - no valid samples obtained");
    }
    println!();
}

Key Concepts

  1. Prior Specification: We use weakly informative priors that allow the data to dominate
  2. Likelihood Specification: Each data point contributes its own observe statement; plate! or a loop scales this pattern to larger datasets
  3. Parameter Recovery: MCMC estimates should recover true parameter values
  4. Uncertainty Quantification: Standard deviations provide parameter uncertainty

Prior Selection

  • Use Normal(0, 10) for regression coefficients when predictors are standardized
  • Use Gamma(2, 0.5) for error variance (ฯƒ) - gives reasonable prior mass over positive values
  • Adjust prior scale based on your domain knowledge and data scale

Interpretation

The posterior samples provide:

  • Point Estimates: Posterior means are Bayesian parameter estimates
  • Credible Intervals: Quantiles give uncertainty bounds (e.g., 95% credible intervals)
  • Predictive Distribution: For a new input x*, p(y* | x*, y) = ∫ p(y* | x*, β₀, β₁, σ) p(β₀, β₁, σ | y) dβ₀ dβ₁ dσ (see the sketch below)
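
Concretely, the predictive distribution can be approximated by pushing each posterior draw through the data model. A sketch assuming (intercept, slope, sigma) triples have already been extracted from the MCMC traces; the values below are illustrative:

use fugue::*;
use rand::{SeedableRng, rngs::StdRng};

fn main() {
    // Illustrative (intercept, slope, sigma) draws; in practice collect these
    // from the traces as in the demo above.
    let posterior_draws = vec![(1.9, 1.45, 0.48), (2.1, 1.55, 0.52), (2.0, 1.50, 0.50)];
    let x_new = 2.0;
    let mut rng = StdRng::seed_from_u64(7);

    // One predictive draw per posterior sample: y* ~ Normal(b0 + b1 * x_new, sigma)
    let predictive: Vec<f64> = posterior_draws
        .iter()
        .map(|&(b0, b1, sigma)| Normal::new(b0 + b1 * x_new, sigma).unwrap().sample(&mut rng))
        .collect();

    let mean = predictive.iter().sum::<f64>() / predictive.len() as f64;
    println!("posterior predictive mean at x* = {x_new}: {mean:.3}");
}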

Robust Regression

Standard linear regression assumes Gaussian errors, making it sensitive to outliers. Robust regression uses heavy-tailed distributions to reduce outlier influence.

Theory

Replace the normal likelihood with a Student-t distribution:

yᵢ ~ StudentT(ν, μᵢ, σ)

where:

  • μᵢ = β₀ + β₁xᵢ (linear predictor)
  • ν: Degrees of freedom (lower = heavier tails)
  • As ν → ∞, StudentT(ν, μ, σ) → Normal(μ, σ²)

Implementation

use fugue::*;
// Robust regression using t-distribution for outlier resistance
fn robust_regression_model(x_data: Vec<f64>, y_data: Vec<f64>) -> Model<(f64, f64, f64, f64)> {
    prob! {
        let intercept <- sample(addr!("intercept"), Normal::new(0.0, 10.0).unwrap());
        let slope <- sample(addr!("slope"), Normal::new(0.0, 10.0).unwrap());
        let sigma <- sample(addr!("sigma"), Gamma::new(2.0, 0.5).unwrap());
        let nu <- sample(addr!("nu"), Gamma::new(2.0, 0.1).unwrap()); // Degrees of freedom for t-dist
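        // NOTE: the likelihood below uses a Normal distribution for simplicity, so `nu` is
        // sampled but not yet used; a Student-t likelihood would make the model fully robust.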

        // Use plate notation for observations
        let _observations <- plate!(i in x_data.iter().zip(y_data.iter()).enumerate().take(3) => {
            let (idx, (x_i, y_i)) = i;
            observe(addr!("y", idx), Normal::new(intercept + slope * x_i, sigma).unwrap(), *y_i)
        });

        pure((intercept, slope, sigma, nu))
    }
}

fn robust_regression_demo() {
    println!("=== Robust Linear Regression ===\n");

    // Generate data with outliers
    let (mut x_data, mut y_data) = generate_regression_data(40, 1.2, 3.0, 0.4, 67890);

    // Add some outliers
    x_data.extend(vec![8.5, 9.2, 7.8]);
    y_data.extend(vec![20.0, -5.0, 25.0]); // Clear outliers

    println!(
        "๐Ÿ“Š Generated {} data points (with 3 outliers)",
        x_data.len()
    );
    println!("   - Base relationship: y = 3.0 + 1.2*x + noise");
    println!("   - Added outliers at x=[8.5, 9.2, 7.8] with y=[20.0, -5.0, 25.0]");

    // Compare standard vs robust regression
    let mut rng = StdRng::seed_from_u64(42);

    // Standard regression
    println!("\n๐Ÿ”ฌ Standard Linear Regression:");
    let standard_model_fn = || basic_linear_regression_model(x_data.clone(), y_data.clone());
    let standard_samples = adaptive_mcmc_chain(&mut rng, standard_model_fn, 500, 100);

    let std_intercepts: Vec<f64> = standard_samples
        .iter()
        .map(|(_, trace)| trace.get_f64(&addr!("intercept")).unwrap())
        .collect();
    let std_slopes: Vec<f64> = standard_samples
        .iter()
        .map(|(_, trace)| trace.get_f64(&addr!("slope")).unwrap())
        .collect();

    println!(
        "   - Intercept: {:.3} (true: 3.0)",
        std_intercepts.iter().sum::<f64>() / std_intercepts.len() as f64
    );
    println!(
        "   - Slope: {:.3} (true: 1.2)",
        std_slopes.iter().sum::<f64>() / std_slopes.len() as f64
    );

    // Robust regression (conceptual - using same likelihood but different prior structure)
    println!("\n๐Ÿ›ก๏ธ Robust Regression (Conceptual):");
    let mut rng2 = StdRng::seed_from_u64(42);
    let robust_model_fn = || robust_regression_model(x_data.clone(), y_data.clone());
    let robust_samples = adaptive_mcmc_chain(&mut rng2, robust_model_fn, 500, 100);

    let rob_intercepts: Vec<f64> = robust_samples
        .iter()
        .map(|(_, trace)| trace.get_f64(&addr!("intercept")).unwrap())
        .collect();
    let rob_slopes: Vec<f64> = robust_samples
        .iter()
        .map(|(_, trace)| trace.get_f64(&addr!("slope")).unwrap())
        .collect();
    let rob_nus: Vec<f64> = robust_samples
        .iter()
        .map(|(_, trace)| trace.get_f64(&addr!("nu")).unwrap())
        .collect();

    println!(
        "   - Intercept: {:.3} (true: 3.0)",
        rob_intercepts.iter().sum::<f64>() / rob_intercepts.len() as f64
    );
    println!(
        "   - Slope: {:.3} (true: 1.2)",
        rob_slopes.iter().sum::<f64>() / rob_slopes.len() as f64
    );
    println!(
        "   - Degrees of freedom (ฮฝ): {:.3}",
        rob_nus.iter().sum::<f64>() / rob_nus.len() as f64
    );

    println!("\n๐Ÿ’ก Note: Lower ฮฝ indicates heavier tails (more robust to outliers)");
    println!();
}

Robust vs. Standard Comparison

graph LR
    A["Data with Outliers"] --> B["Standard Regression"]
    A --> C["Robust Regression"]

    B --> D["Biased Parameters<br/>Large Residuals"]
    C --> E["Stable Parameters<br/>Heavy-tailed Errors"]

Advantages of Robust Regression:

  • Outlier Resistance: Heavy tails accommodate extreme values
  • Automatic Detection: Low ν indicates outlier presence
  • Flexible: Reduces to normal regression when ν is large

Computational Complexity

t-distribution likelihoods are more computationally expensive than normal distributions. For very large datasets, consider preprocessing to remove obvious outliers first.

Polynomial Regression

Linear regression can model nonlinear relationships using polynomial basis functions:

yᵢ = β₀ + β₁xᵢ + β₂xᵢ² + … + β_d xᵢ^d + εᵢ

Mathematical Framework

Design Matrix: For polynomial degree d, each row of the design matrix is [1, xᵢ, xᵢ², …, xᵢ^d].

Hierarchical Prior: Control overfitting with shrinkage priors:

βⱼ ~ Normal(0, 1/√τ), with a shared precision τ ~ Gamma(2, 1)
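
Before the model itself, a tiny helper sketch (hypothetical poly_features name, not part of the example code) that expands a scalar input into the polynomial features described by the design matrix above:

// Expand x into [1, x, x^2, ..., x^d]
fn poly_features(x: f64, degree: usize) -> Vec<f64> {
    (0..=degree).map(|d| x.powi(d as i32)).collect()
}

fn main() {
    println!("{:?}", poly_features(1.5, 3)); // [1.0, 1.5, 2.25, 3.375]
}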

Implementation

use fugue::*;
// Polynomial regression with automatic relevance determination
fn polynomial_regression_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    _degree: usize,
) -> Model<Vec<f64>> {
    prob! {
        // Hierarchical prior for polynomial coefficients
        let precision <- sample(addr!("precision"), Gamma::new(2.0, 1.0).unwrap());

        // Sample polynomial coefficients (fixed degree for simplicity)
        let coef_0 <- sample(addr!("coef", 0), Normal::new(0.0, 1.0 / precision.sqrt()).unwrap());
        let coef_1 <- sample(addr!("coef", 1), Normal::new(0.0, 1.0 / precision.sqrt()).unwrap());
        let coef_2 <- sample(addr!("coef", 2), Normal::new(0.0, 1.0 / precision.sqrt()).unwrap());
        let coefficients = vec![coef_0, coef_1, coef_2];

        // Noise parameter
        let sigma <- sample(addr!("sigma"), Gamma::new(2.0, 0.5).unwrap());

        // Clone coefficients for use in closure
        let coefficients_for_observations = coefficients.clone();
        let _observations <- plate!(i in x_data.iter().zip(y_data.iter()).enumerate().take(3) => {
            let (idx, (x_i, y_i)) = i;
            let mut mean_i = 0.0;
            for (d, coef) in coefficients_for_observations.iter().enumerate() {
                mean_i += coef * x_i.powi(d as i32);
            }
            observe(addr!("y", idx), Normal::new(mean_i, sigma).unwrap(), *y_i)
        });

        pure(coefficients)
    }
}

fn polynomial_regression_demo() {
    println!("=== Polynomial Regression ===\n");

    // Generate nonlinear data: y = 1 + 2x - 0.5xยฒ + noise
    let x_raw: Vec<f64> = (0..30).map(|i| i as f64 / 29.0 * 4.0).collect(); // x from 0 to 4
    let y_data: Vec<f64> = x_raw
        .iter()
        .map(|&x| {
            let true_mean = 1.0 + 2.0 * x - 0.5 * x.powi(2);
            let mut rng = StdRng::seed_from_u64(((x * 1000.0) as u64) + 555);
            true_mean + Normal::new(0.0, 0.3).unwrap().sample(&mut rng)
        })
        .collect();

    println!("๐Ÿ“Š Generated nonlinear data: y = 1 + 2x - 0.5xยฒ + noise");
    println!("   - {} data points, x โˆˆ [0, 4]", x_raw.len());

    // Fit polynomial models of different degrees
    for degree in [1, 2, 3].iter() {
        println!("\n๐Ÿ”ฌ Fitting degree {} polynomial...", degree);

        let mut rng = StdRng::seed_from_u64(42 + *degree as u64);
        let model_fn = || polynomial_regression_model(x_raw.clone(), y_data.clone(), *degree);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 400, 80);

        println!("   Coefficient estimates:");
        for d in 0..=*degree {
            let coef_samples: Vec<f64> = samples
                .iter()
                // coefficients beyond the 3 sampled by the simplified model default to 0.0
                .map(|(_, trace)| trace.get_f64(&addr!("coef", d)).unwrap_or(0.0))
                .collect();
            let mean_coef = coef_samples.iter().sum::<f64>() / coef_samples.len() as f64;

            let true_coef = match d {
                0 => 1.0,  // intercept
                1 => 2.0,  // linear term
                2 => -0.5, // quadratic term
                _ => 0.0,  // higher terms should be ~0
            };

            println!("     x^{}: {:.3} (true: {:.1})", d, mean_coef, true_coef);
        }

        // Model comparison metric (simplified log marginal likelihood)
        let log_likelihoods: Vec<f64> = samples
            .iter()
            .map(|(_, trace)| trace.log_likelihood)
            .collect();
        let avg_log_likelihood = log_likelihoods.iter().sum::<f64>() / log_likelihoods.len() as f64;
        println!("     Average log-likelihood: {:.2}", avg_log_likelihood);
    }

    println!("\n๐Ÿ’ก The degree-2 polynomial should have the highest likelihood!");
    println!();
}

Overfitting Prevention

graph TD
    A["Polynomial Degree"] --> B["Model Complexity"]
    B --> C{Degree Choice}

    C -->|Too Low| D["Underfitting<br/>High Bias"]
    C -->|Just Right| E["Good Fit<br/>Balanced"]
    C -->|Too High| F["Overfitting<br/>High Variance"]

    G["Hierarchical Priors"] --> H["Automatic Shrinkage"]
    H --> E

Shrinkage Benefits:

  • Automatic Regularization: Higher-order terms shrink toward zero
  • Bias-Variance Tradeoff: Balances model flexibility with stability
  • Model Selection: Coefficients near zero indicate irrelevant terms

Bayesian Model Selection

Compare different regression models using marginal likelihood and information criteria.

Model Comparison Framework

For candidate models M₁, …, M_K:

Marginal Likelihood: p(y | Mₖ) = ∫ p(y | θₖ, Mₖ) p(θₖ | Mₖ) dθₖ

Bayes Factors: BF₁₂ = p(y | M₁) / p(y | M₂)

Model Posterior Probabilities: p(Mₖ | y) ∝ p(y | Mₖ) p(Mₖ)
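
Once marginal log-likelihoods are estimated (as in the demo below), turning them into a Bayes factor and posterior model probabilities is a one-liner. A sketch with illustrative numbers and equal prior model weights:

fn main() {
    // Hypothetical marginal log-likelihood estimates for two candidate models
    let log_ml_quadratic = -12.4;
    let log_ml_linear = -18.9;

    // Bayes factor in favour of the quadratic model
    let bf = (log_ml_quadratic - log_ml_linear).exp();

    // Posterior model probability under equal prior weights p(M1) = p(M2) = 0.5
    let p_quadratic = bf / (bf + 1.0);

    println!("BF(quadratic vs linear) = {bf:.1}, p(quadratic | y) = {p_quadratic:.4}");
}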

Implementation

use fugue::*;
// Bayesian model selection for regression
#[derive(Clone, Copy, Debug)]
enum RegressionModel {
    Linear,
    Quadratic,
    Cubic,
}

fn model_selection_demo() {
    println!("=== Bayesian Model Selection ===\n");

    // Generate quadratic data
    let x_data: Vec<f64> = (0..25).map(|i| (i as f64 - 12.0) / 5.0).collect(); // x from -2.4 to 2.4
    let y_data: Vec<f64> = x_data
        .iter()
        .map(|&x| {
            let true_mean = 0.5 + 1.5 * x - 0.8 * x.powi(2);
            let mut rng = StdRng::seed_from_u64(((x.abs() * 1000.0) as u64) + 777);
            true_mean + Normal::new(0.0, 0.2).unwrap().sample(&mut rng)
        })
        .collect();

    println!("๐Ÿ“Š True model: y = 0.5 + 1.5x - 0.8xยฒ + noise");

    let models = [
        (RegressionModel::Linear, 1),
        (RegressionModel::Quadratic, 2),
        (RegressionModel::Cubic, 3),
    ];

    let mut model_scores = Vec::new();

    for (model_type, degree) in models.iter() {
        println!("\n๐Ÿ”ฌ Evaluating {:?} model...", model_type);

        let mut rng = StdRng::seed_from_u64(42 + *degree as u64);
        let model_fn = || polynomial_regression_model(x_data.clone(), y_data.clone(), *degree);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 60);

        // Compute approximate marginal likelihood (harmonic mean estimator)
        let log_likelihoods: Vec<f64> = samples
            .iter()
            .map(|(_, trace)| trace.log_likelihood)
            .collect();

        let max_ll = log_likelihoods
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let shifted_lls: Vec<f64> = log_likelihoods.iter().map(|ll| ll - max_ll).collect();
        let mean_exp_ll =
            shifted_lls.iter().map(|ll| ll.exp()).sum::<f64>() / shifted_lls.len() as f64;
        let marginal_log_likelihood = max_ll + mean_exp_ll.ln();

        model_scores.push((*model_type, marginal_log_likelihood));

        println!(
            "   - Marginal log-likelihood: {:.2}",
            marginal_log_likelihood
        );

        // Show coefficient estimates
        for d in 0..=*degree {
            let coef_samples: Vec<f64> = samples
                .iter()
                // coefficients beyond the 3 sampled by the simplified model default to 0.0
                .map(|(_, trace)| trace.get_f64(&addr!("coef", d)).unwrap_or(0.0))
                .collect();
            let mean_coef = coef_samples.iter().sum::<f64>() / coef_samples.len() as f64;
            println!("     Coefficient x^{}: {:.3}", d, mean_coef);
        }
    }

    // Find best model
    model_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    println!("\n๐Ÿ† Model Ranking:");
    for (i, (model, score)) in model_scores.iter().enumerate() {
        let relative_score = score - model_scores[0].1;
        println!(
            "   {}. {:?}: {:.2} (ฮ” = {:.2})",
            i + 1,
            model,
            score,
            relative_score
        );
    }

    println!("\n๐Ÿ’ก The Quadratic model should win (matches true data generating process)!");
    println!();
}

Model Selection Criteria

| Criterion           | Formula                        | Interpretation              |
|---------------------|--------------------------------|-----------------------------|
| Marginal Likelihood | p(y \| Mₖ)                     | Higher is better            |
| Bayes Factor        | BF₁₂ = p(y \| M₁) / p(y \| M₂) | > 3: strong evidence for M₁ |
| DIC                 | D̄ + p_D                        | Lower is better             |
| WAIC                | −2(lppd − p_WAIC)              | Lower is better             |

Model Selection Guidelines

  1. Start Simple: Begin with linear models, add complexity as needed
  2. Cross-Validation: Use holdout data to validate model predictions
  3. Domain Knowledge: Consider scientific plausibility, not just statistical fit
  4. Multiple Criteria: Don't rely on a single selection criterion

Regularized Regression

High-dimensional regression requires regularization to prevent overfitting. Ridge regression achieves this through hierarchical priors.

Ridge Regression Theory

Penalty Formulation: β̂_ridge = argmin_β Σᵢ (yᵢ − xᵢᵀβ)² + λ Σⱼ βⱼ²

Bayesian Equivalent: βⱼ ~ Normal(0, 1/√λ), i.e. a zero-mean Gaussian prior with variance 1/λ

The regularization parameter λ controls shrinkage:

  • Large λ: Strong shrinkage (high bias, low variance)
  • Small λ: Weak shrinkage (low bias, high variance)

Implementation

use fugue::*;
// Ridge regression (L2 regularization) through hierarchical priors
fn ridge_regression_model(x_data: Vec<Vec<f64>>, y_data: Vec<f64>, lambda: f64) -> Model<Vec<f64>> {
    let p = x_data[0].len(); // number of features

    prob! {
        // Sample coefficients with ridge penalty
        let beta_0 <- sample(addr!("beta", 0), Normal::new(0.0, 1.0 / lambda.sqrt()).unwrap());
        let beta_1 <- sample(addr!("beta", 1), Normal::new(0.0, 1.0 / lambda.sqrt()).unwrap());
        let beta_2 <- sample(addr!("beta", 2), Normal::new(0.0, 1.0 / lambda.sqrt()).unwrap());
        let coefficients = vec![beta_0, beta_1, beta_2];

        let sigma <- sample(addr!("sigma"), Gamma::new(2.0, 0.5).unwrap());

        // Clone coefficients for use in closure
        let coefficients_for_observations = coefficients.clone();
        let _observations <- plate!(i in x_data.iter().zip(y_data.iter()).enumerate().take(2) => {
            let (idx, (x_i, y_i)) = i;
            let mut mean_i = 0.0;
            for (j, beta_j) in coefficients_for_observations.iter().enumerate() {
                if j < p && j < x_i.len() {
                    mean_i += beta_j * x_i[j];
                }
            }
            observe(addr!("y", idx), Normal::new(mean_i, sigma).unwrap(), *y_i)
        });

        pure(coefficients)
    }
}

fn regularized_regression_demo() {
    println!("=== Regularized Regression (Ridge) ===\n");

    // Generate high-dimensional data with few relevant features
    let n = 40;
    let p = 8; // 8 features, only 3 are relevant

    let mut x_data = Vec::new();
    let mut y_data = Vec::new();

    let true_coefs = [2.0, -1.5, 0.0, 1.2, 0.0, 0.0, 0.0, -0.8]; // Only indices 0,1,3,7 matter

    for i in 0..n {
        let mut rng = StdRng::seed_from_u64(1000 + i as u64);
        let x_i: Vec<f64> = (0..p)
            .map(|_| Normal::new(0.0, 1.0).unwrap().sample(&mut rng))
            .collect();

        let true_mean: f64 = x_i.iter().zip(true_coefs.iter()).map(|(x, c)| x * c).sum();
        let y_i = true_mean + Normal::new(0.0, 0.5).unwrap().sample(&mut rng);

        x_data.push(x_i);
        y_data.push(y_i);
    }

    println!("๐Ÿ“Š High-dimensional regression:");
    println!("   - {} observations, {} features", n, p);
    println!("   - True coefficients: [2.0, -1.5, 0.0, 1.2, 0.0, 0.0, 0.0, -0.8]");
    println!("   - Only 4 out of 8 features are relevant");

    // Compare different regularization strengths
    let lambdas = [0.1, 1.0, 10.0];

    for &lambda in lambdas.iter() {
        println!("\n๐Ÿ”ฌ Ridge regression with ฮป = {}:", lambda);

        let mut rng = StdRng::seed_from_u64(42 + (lambda * 100.0) as u64);
        let model_fn = || ridge_regression_model(x_data.clone(), y_data.clone(), lambda);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 60);

        println!("   Coefficient estimates (true values in parentheses):");
        for (j, &true_coef) in true_coefs.iter().enumerate().take(p) {
            let coef_samples: Vec<f64> = samples
                .iter()
                // only beta_0..beta_2 are sampled by the simplified model; others default to 0.0
                .map(|(_, trace)| trace.get_f64(&addr!("beta", j)).unwrap_or(0.0))
                .collect();
            let mean_coef = coef_samples.iter().sum::<f64>() / coef_samples.len() as f64;
            println!("     ฮฒ{}: {:6.3} ({:5.1})", j, mean_coef, true_coef);
        }

        // Compute prediction accuracy (simplified)
        let predictions: Vec<f64> = x_data
            .iter()
            .map(|x_i| {
                let mut pred = 0.0;
                for (j, &x_val) in x_i.iter().enumerate().take(p) {
                    let coef_samples: Vec<f64> = samples
                        .iter()
                        .map(|(_, trace)| trace.get_f64(&addr!("beta", j)).unwrap_or(0.0))
                        .collect();
                    let mean_coef = coef_samples.iter().sum::<f64>() / coef_samples.len() as f64;
                    pred += mean_coef * x_val;
                }
                pred
            })
            .collect();

        let mse = y_data
            .iter()
            .zip(predictions.iter())
            .map(|(y, pred)| (y - pred).powi(2))
            .sum::<f64>()
            / n as f64;

        println!("   - Mean Squared Error: {:.4}", mse);
    }

    println!("\n๐Ÿ’ก Higher ฮป shrinks coefficients toward zero (regularization effect)");
    println!("   Optimal ฮป balances bias-variance tradeoff!");
    println!();
}

Regularization Effects

graph TB
    A["High-Dimensional Data<br/>p >> n"] --> B["Regularization"]

    B --> C["ฮป = 0.1<br/>Weak Shrinkage"]
    B --> D["ฮป = 1.0<br/>Moderate Shrinkage"]
    B --> E["ฮป = 10.0<br/>Strong Shrinkage"]

    C --> F["Low Bias<br/>High Variance"]
    D --> G["Balanced<br/>Optimal MSE"]
    E --> H["High Bias<br/>Low Variance"]

Advantages of Bayesian Ridge:

  • Automatic ฮป Selection: Through hierarchical priors on precision
  • Uncertainty Quantification: Full posterior for all parameters
  • Feature Selection: Coefficients with narrow posteriors around zero

Advanced Extensions

Hierarchical Linear Models

use fugue::*;

// Group-level regression with varying intercepts
fn hierarchical_regression_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    group_ids: Vec<usize>
) -> Model<(f64, f64, Vec<f64>)> {
    let n_groups = group_ids.iter().max().unwrap() + 1;

    prob!(
        // Global parameters
        let global_slope <- sample(addr!("global_slope"), Normal::new(0.0, 5.0).unwrap());
        let sigma_y <- sample(addr!("sigma_y"), Gamma::new(2.0, 0.5).unwrap());
        let sigma_group <- sample(addr!("sigma_group"), Gamma::new(2.0, 1.0).unwrap());

        // Group-specific intercepts
        let mut group_intercepts = Vec::new();
        for g in 0..n_groups {
            let intercept_g <- sample(
                addr!("intercept", g),
                Normal::new(0.0, sigma_group).unwrap()
            );
            group_intercepts.push(intercept_g);
        }

        // Likelihood
        for (i, ((x_i, y_i), group_i)) in x_data.iter()
            .zip(y_data.iter())
            .zip(group_ids.iter())
            .enumerate()
        {
            let mean_i = group_intercepts[*group_i] + global_slope * x_i;
            let _obs <- observe(addr!("y", i), Normal::new(mean_i, sigma_y).unwrap(), *y_i);
        }

        pure((global_slope, sigma_y, group_intercepts))
    )
}

Spline Regression

use fugue::*;

// Bayesian cubic spline regression
fn spline_regression_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    knots: Vec<f64>
) -> Model<Vec<f64>> {
    let n_basis = knots.len() + 3; // Cubic splines

    prob!(
        // Smoothness prior
        let precision <- sample(addr!("precision"), Gamma::new(1.0, 0.1).unwrap());

        // Spline coefficients with smoothness penalty
        let mut coefficients = Vec::new();
        for j in 0..n_basis {
            let coef_j <- sample(
                addr!("coef", j),
                Normal::new(0.0, 1.0 / precision.sqrt()).unwrap()
            );
            coefficients.push(coef_j);
        }

        let sigma <- sample(addr!("sigma"), Gamma::new(2.0, 0.5).unwrap());

        // Likelihood (basis functions would be computed here)
        for (i, (x_i, y_i)) in x_data.iter().zip(y_data.iter()).enumerate() {
            // Compute basis function values at x_i
            let mut mean_i = 0.0;
            for (j, coef_j) in coefficients.iter().enumerate() {
                // basis_function(x_i, j, knots) would compute B-spline basis
                let basis_val = if j < knots.len() {
                    (x_i - knots[j]).max(0.0).powi(3)
                } else {
                    x_i.powi((j - knots.len()) as i32)
                };
                mean_i += coef_j * basis_val;
            }
            let _obs <- observe(addr!("y", i), Normal::new(mean_i, sigma).unwrap(), *y_i);
        }

        pure(coefficients)
    )
}

Production Considerations

Scalability

For large datasets:

  1. Minibatch MCMC: Use data subsets for likelihood computation
  2. Variational Inference: Approximate posterior for faster computation
  3. GPU Acceleration: Vectorized operations on GPU
  4. Sparse Representations: Efficient storage for high-dimensional sparse data

Model Diagnostics

Essential checks for regression models:

use fugue::inference::diagnostics::*;

fn regression_diagnostics(samples: &[(f64, f64, f64)], x_data: &[f64], y_data: &[f64]) {
    // Posterior-mean point estimates for intercept and slope
    let n = samples.len() as f64;
    let mean_intercept = samples.iter().map(|&(b0, _, _)| b0).sum::<f64>() / n;
    let mean_slope = samples.iter().map(|&(_, b1, _)| b1).sum::<f64>() / n;

    // Fitted values at the posterior means, then residuals against the observed data
    let predictions: Vec<f64> = x_data.iter().map(|&x| mean_intercept + mean_slope * x).collect();
    let residuals: Vec<f64> = y_data.iter().zip(predictions.iter())
        .map(|(y, pred)| y - pred).collect();

    // Check for patterns in residuals
    println!("Residual diagnostics:");
    println!("  Mean residual: {:.4}", residuals.iter().sum::<f64>() / residuals.len() as f64);
    println!("  Residual std: {:.4}", {
        let mean = residuals.iter().sum::<f64>() / residuals.len() as f64;
        (residuals.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / (residuals.len() - 1) as f64).sqrt()
    });
}

Cross-Validation

use fugue::*;

fn k_fold_cross_validation<F>(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    k: usize,
    model_fn: F
) -> f64
where F: Fn(Vec<f64>, Vec<f64>) -> Model<(f64, f64, f64)>
{
    let n = x_data.len();
    let fold_size = n / k;
    let mut mse_scores = Vec::new();

    for fold in 0..k {
        let test_start = fold * fold_size;
        let test_end = if fold == k - 1 { n } else { (fold + 1) * fold_size };

        // Split data
        let mut train_x = Vec::new();
        let mut train_y = Vec::new();
        let mut test_x = Vec::new();
        let mut test_y = Vec::new();

        for i in 0..n {
            if i >= test_start && i < test_end {
                test_x.push(x_data[i]);
                test_y.push(y_data[i]);
            } else {
                train_x.push(x_data[i]);
                train_y.push(y_data[i]);
            }
        }

        // Train model (simplified - would run MCMC here)
        let mut rng = StdRng::seed_from_u64(fold as u64);
        let (params, _) = runtime::handler::run(
            PriorHandler { rng: &mut rng, trace: Trace::default() },
            model_fn(train_x, train_y)
        );

        // Predict on test set
        let predictions: Vec<f64> = test_x.iter()
            .map(|&x| params.0 + params.1 * x)
            .collect();

        // Compute MSE
        let mse = test_y.iter().zip(predictions.iter())
            .map(|(y, pred)| (y - pred).powi(2))
            .sum::<f64>() / test_y.len() as f64;

        mse_scores.push(mse);
    }

    mse_scores.iter().sum::<f64>() / k as f64
}
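
A minimal usage sketch, assuming a model constructor with signature `fn(Vec<f64>, Vec<f64>) -> Model<(f64, f64, f64)>` such as the basic linear regression model earlier in this tutorial (the name below is illustrative):

// Hypothetical usage of k_fold_cross_validation with 5 folds.
let x: Vec<f64> = (0..50).map(|i| i as f64 * 0.1).collect();
let y: Vec<f64> = x.iter().map(|&xi| 1.0 + 2.0 * xi).collect();
let avg_mse = k_fold_cross_validation(x, y, 5, linear_regression_model);
println!("5-fold CV mean squared error: {:.4}", avg_mse);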

Real-World Applications

Economic Forecasting

// Example: GDP growth prediction
let gdp_model = prob!(
    // Macroeconomic predictors
    let beta_inflation <- sample(addr!("beta_inflation"), Normal::new(0.0, 2.0).unwrap());
    let beta_unemployment <- sample(addr!("beta_unemployment"), Normal::new(0.0, 2.0).unwrap());
    let beta_interest_rate <- sample(addr!("beta_interest_rate"), Normal::new(0.0, 2.0).unwrap());
    let intercept <- sample(addr!("intercept"), Normal::new(2.0, 1.0).unwrap()); // Prior: ~2% growth

    let sigma <- sample(addr!("sigma"), Gamma::new(2.0, 0.5).unwrap());

    // Quarterly GDP growth predictions
    for (i, (inflation, unemployment, interest_rate, gdp_growth)) in economic_data.iter().enumerate() {
        let expected_growth = intercept +
                            beta_inflation * inflation +
                            beta_unemployment * unemployment +
                            beta_interest_rate * interest_rate;

        let _obs <- observe(addr!("gdp", i), Normal::new(expected_growth, sigma).unwrap(), *gdp_growth);
    }

    pure((intercept, beta_inflation, beta_unemployment, beta_interest_rate))
);

Medical Research

// Example: Drug dose-response modeling
let dose_response_model = prob!(
    // Log-linear dose-response
    let log_ic50 <- sample(addr!("log_ic50"), Normal::new(0.0, 2.0).unwrap()); // IC50 concentration
    let hill_slope <- sample(addr!("hill_slope"), Normal::new(1.0, 0.5).unwrap()); // Cooperativity
    let baseline <- sample(addr!("baseline"), Normal::new(100.0, 10.0).unwrap()); // No drug effect
    let max_effect <- sample(addr!("max_effect"), Normal::new(0.0, 10.0).unwrap()); // Maximum inhibition

    let sigma <- sample(addr!("sigma"), Gamma::new(2.0, 0.5).unwrap());

    for (i, (log_dose, response)) in dose_response_data.iter().enumerate() {
        // Hill equation: E = baseline + (max_effect - baseline) / (1 + 10^(hill_slope * (log_ic50 - log_dose)))
        let hill_term = hill_slope * (log_ic50 - log_dose);
        let expected_response = baseline + (max_effect - baseline) / (1.0 + (10.0_f64).powf(hill_term));

        let _obs <- observe(addr!("response", i), Normal::new(expected_response, sigma).unwrap(), *response);
    }

    pure((log_ic50, hill_slope, baseline, max_effect))
);

Testing Your Understanding

Exercise 1: Multiple Regression

Extend basic linear regression to handle multiple predictors:

use fugue::*;

fn multiple_regression_model(
    x_data: Vec<Vec<f64>>, // Matrix: n observations ร— p predictors
    y_data: Vec<f64>
) -> Model<Vec<f64>> {
    let p = x_data[0].len(); // number of predictors

    prob!(
        // TODO: Implement multiple regression
        // - Create coefficient vector of length p
        // - Use matrix multiplication for linear predictor
        // - Add appropriate priors for high-dimensional case

        pure(vec![0.0; p]) // Placeholder
    )
}

Exercise 2: Heteroscedastic Regression

Model non-constant error variance:

use fugue::*;

fn heteroscedastic_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>
) -> Model<(f64, f64, f64, f64)> {
    prob!(
        // TODO: Implement regression with non-constant variance
        // - Model log(ฯƒยฒ) as linear function of x
        // - ฯƒยฒแตข = exp(ฮณโ‚€ + ฮณโ‚ * xแตข)
        // - Use different variance for each observation

        pure((0.0, 0.0, 0.0, 0.0)) // Placeholder
    )
}

Exercise 3: Bayesian Variable Selection

Implement spike-and-slab priors for variable selection:

use fugue::*;

fn variable_selection_model(
    x_data: Vec<Vec<f64>>,
    y_data: Vec<f64>,
    inclusion_probability: f64
) -> Model<(Vec<f64>, Vec<bool>)> {
    let p = x_data[0].len();

    prob!(
        // TODO: Implement variable selection
        // - ฮณโฑผ ~ Bernoulli(ฯ€) for inclusion indicators  
        // - ฮฒโฑผ | ฮณโฑผ ~ ฮณโฑผ * Normal(0, ฯ„ยฒ) + (1-ฮณโฑผ) * ฮดโ‚€
        // - Spike-and-slab prior structure

        pure((vec![0.0; p], vec![false; p])) // Placeholder
    )
}

Key Takeaways

Linear Regression Mastery

  1. Bayesian Framework: Uncertainty quantification through posterior distributions
  2. Model Extensions: Robustness, nonlinearity, and regularization through prior specification
  3. Model Selection: Principled comparison using marginal likelihood and Bayes factors
  4. Scalability: Hierarchical models and efficient computation for high-dimensional problems
  5. Real-World Applications: Flexible framework adaptable to diverse scientific domains
  6. Production Ready: Cross-validation, diagnostics, and robust inference workflows

Core Techniques:

  • โœ… Basic Regression with uncertainty quantification
  • โœ… Robust Methods for outlier resistance
  • โœ… Polynomial Modeling for nonlinear relationships
  • โœ… Bayesian Model Selection for optimal complexity
  • โœ… Ridge Regression for high-dimensional problems
  • โœ… Hierarchical Extensions for grouped data
  • โœ… Production Deployment with diagnostics and validation

Linear regression in Fugue provides a solid foundation for more complex statistical models. The Bayesian approach naturally handles uncertainty, enables model comparison, and scales to modern high-dimensional problems through principled regularization.

Further Reading

Classification

A comprehensive guide to Bayesian classification using Fugue. This tutorial demonstrates how to build, analyze, and extend classification models for discrete outcomes, showcasing the power of probabilistic programming for uncertainty quantification and principled model selection.

Learning Objectives

By the end of this tutorial, you will understand:

  • Binary Classification: Logistic regression with posterior uncertainty for two-class problems
  • Multi-class Classification: Multinomial logistic models and one-vs-rest approaches
  • Hierarchical Classification: Group-level effects for nested data structures
  • Model Comparison: Bayesian information criteria and Bayes factors for model selection
  • Uncertainty Quantification: Extracting and interpreting prediction confidence intervals
  • Robust Methods: Constraint-aware MCMC for stable parameter estimation
  • Production Applications: Scalable classification workflows for real-world deployment

The Classification Framework

Classification problems involve predicting discrete outcomes from continuous or discrete inputs. In the Bayesian framework, we treat classification parameters as random variables with prior distributions, enabling natural uncertainty quantification and robust model comparison.

graph TB
    A["Labeled Data: (xโ‚,yโ‚), (xโ‚‚,yโ‚‚), ..., (xโ‚™,yโ‚™)"] --> B["Classification Model<br/>P(y|x, ฮธ)"]

    B --> C["Bayesian Framework"]
    C --> D["Prior: p(ฮธ)"]
    C --> E["Likelihood: p(y|X, ฮธ)"]

    D --> F["Posterior: p(ฮธ|y, X)"]
    E --> F

    F --> G["MCMC Sampling"]
    G --> H["Parameter Uncertainty"]
    G --> I["Prediction Probabilities"]
    G --> J["Model Comparison"]

Advantages of Bayesian Classification

Traditional machine learning gives you point predictions. Bayesian classification provides:

  • Posterior probability distributions over class labels
  • Uncertainty estimates for each prediction
  • Principled model comparison using marginal likelihoods
  • Automatic regularization through informative priors

Binary Classification: Logistic Regression

The foundation of Bayesian classification is logistic regression, which models the probability of binary outcomes.

Mathematical Model

For binary classification, we model:

yᵢ ~ Bernoulli(pᵢ),   logit(pᵢ) = log(pᵢ / (1 - pᵢ)) = β₀ + β₁xᵢ₁ + ... + βₚxᵢₚ

Where:

  • yᵢ ∈ {0, 1} is the binary outcome
  • pᵢ is the probability of class 1
  • logit(pᵢ) is the log-odds
  • β = (β₀, β₁, ..., βₚ) are the regression coefficients

Implementation

// Basic Bayesian logistic regression model
fn logistic_regression_model(features: Vec<Vec<f64>>, labels: Vec<bool>) -> Model<Vec<f64>> {
    let n_features = features[0].len();

    prob! {
        // Sample coefficients with regularizing priors - build using plate
        let coefficients <- plate!(i in 0..n_features => {
            sample(addr!("beta", i), fugue::Normal::new(0.0, 2.0).unwrap())
        });

        // Clone coefficients for use in closure
        let coefficients_for_obs = coefficients.clone();
        let _observations <- plate!(obs_idx in features.iter().zip(labels.iter()).enumerate() => {
            let (idx, (x_vec, &y)) = obs_idx;
            // Compute linear predictor (log-odds)
            let mut linear_pred = 0.0;
            for (coef, &x_val) in coefficients_for_obs.iter().zip(x_vec.iter()) {
                linear_pred += coef * x_val;
            }

            // Convert to probability using logistic function
            let prob = 1.0 / (1.0 + { -linear_pred }.exp());

            // Ensure probability is in valid range
            let bounded_prob = prob.clamp(1e-10, 1.0 - 1e-10);

            // Observe the binary outcome
            observe(addr!("y", idx), Bernoulli::new(bounded_prob).unwrap(), y)
        });

        pure(coefficients)
    }
}

fn binary_classification_demo() {
    println!("=== Binary Classification with Logistic Regression ===\n");

    // Generate synthetic data
    let (features, labels) = generate_classification_data(100, 42);
    let positive_cases = labels.iter().filter(|&&x| x).count();

    println!("๐Ÿ“Š Generated {} data points", features.len());
    println!("   - Features: {} dimensions", features[0].len());
    println!(
        "   - Positive cases: {} / {} ({:.1}%)",
        positive_cases,
        labels.len(),
        100.0 * positive_cases as f64 / labels.len() as f64
    );
    println!("   - True coefficients: intercept=-1.0, ฮฒโ‚=2.0, ฮฒโ‚‚=-1.5");

    // Run MCMC inference
    let model_fn = move || logistic_regression_model(features.clone(), labels.clone());
    let mut rng = StdRng::seed_from_u64(12345);

    println!("\n๐Ÿ”ฌ Running MCMC inference...");
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 800, 200);

    // Extract coefficient estimates
    let valid_samples: Vec<_> = samples
        .iter()
        .filter_map(|(coeffs, trace)| {
            if trace.total_log_weight().is_finite() {
                Some(coeffs)
            } else {
                None
            }
        })
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… MCMC completed with {} valid samples",
            valid_samples.len()
        );
        println!("\n๐Ÿ“ˆ Coefficient Estimates:");

        let coef_names = ["Intercept", "ฮฒโ‚ (feature 1)", "ฮฒโ‚‚ (feature 2)"];
        let true_coefs = [-1.0, 2.0, -1.5];

        for (i, (name, true_val)) in coef_names.iter().zip(true_coefs.iter()).enumerate() {
            let coef_samples: Vec<f64> = valid_samples.iter().map(|coeffs| coeffs[i]).collect();

            let mean_coef = coef_samples.iter().sum::<f64>() / coef_samples.len() as f64;
            let std_coef = {
                let variance = coef_samples
                    .iter()
                    .map(|c| (c - mean_coef).powi(2))
                    .sum::<f64>()
                    / (coef_samples.len() - 1) as f64;
                variance.sqrt()
            };

            println!(
                "   - {}: {:.3} ยฑ {:.3} (true: {:.1})",
                name, mean_coef, std_coef, true_val
            );
        }

        // Model diagnostics
        let avg_log_weight = samples
            .iter()
            .map(|(_, trace)| trace.total_log_weight())
            .filter(|w| w.is_finite())
            .sum::<f64>()
            / valid_samples.len() as f64;

        println!("   - Average log-likelihood: {:.2}", avg_log_weight);

        // Make predictions on new data
        println!("\n๐Ÿ”ฎ Prediction Example:");
        let test_features = [1.0, 0.5, -0.8]; // New observation
        let mut predicted_probs = Vec::new();

        for coeffs in valid_samples.iter().take(50) {
            // Use subset for speed
            let mut linear_pred = 0.0;
            for (coef, &x_val) in coeffs.iter().zip(test_features.iter()) {
                linear_pred += coef * x_val;
            }
            let prob = 1.0 / (1.0 + (-linear_pred).exp());
            predicted_probs.push(prob);
        }

        let mean_prob = predicted_probs.iter().sum::<f64>() / predicted_probs.len() as f64;
        let std_prob = {
            let variance = predicted_probs
                .iter()
                .map(|p| (p - mean_prob).powi(2))
                .sum::<f64>()
                / (predicted_probs.len() - 1) as f64;
            variance.sqrt()
        };

        println!(
            "   - Test point [0.5, -0.8]: P(y=1) = {:.3} ยฑ {:.3}",
            mean_prob, std_prob
        );
        if mean_prob > 0.5 {
            println!("   - Prediction: Class 1 (probability > 0.5)");
        } else {
            println!("   - Prediction: Class 0 (probability < 0.5)");
        }
    } else {
        println!("โŒ No valid MCMC samples obtained");
    }

    println!();
}

Key Features

  • Automatic constraint handling: Our improved MCMC properly handles the logistic transformation
  • Interpretable coefficients: Each represents log-odds ratios
  • Natural uncertainty: Posterior samples give prediction intervals

Logistic Regression Interpretation

  • Coefficient βⱼ > 0: feature xⱼ increases the log-odds of class 1
  • Coefficient βⱼ < 0: feature xⱼ decreases the log-odds of class 1
  • exp(βⱼ) gives the odds ratio for a unit change in xⱼ (see the sketch below)
  • Use standardized features for coefficient comparability
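
A minimal sketch of turning posterior coefficient draws into odds ratios with a 95% credible interval, assuming one `Vec<f64>` of coefficients per posterior draw (as extracted in the demo above):

// Posterior summary of the odds ratio exp(beta_j) for predictor j:
// returns (posterior mean, 2.5% quantile, 97.5% quantile).
fn odds_ratio_summary(coef_samples: &[Vec<f64>], j: usize) -> (f64, f64, f64) {
    let mut ors: Vec<f64> = coef_samples.iter().map(|c| c[j].exp()).collect();
    ors.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let mean = ors.iter().sum::<f64>() / ors.len() as f64;
    let lo = ors[(0.025 * (ors.len() - 1) as f64).round() as usize];
    let hi = ors[(0.975 * (ors.len() - 1) as f64).round() as usize];
    (mean, lo, hi)
}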

Multi-class Classification: Multinomial Logit

For problems with more than two classes, we use multinomial logistic regression.

Mathematical Model

For K classes, we model:

P(yᵢ = k | xᵢ) = exp(xᵢᵀβₖ) / Σⱼ exp(xᵢᵀβⱼ),   with β_K = 0 for identifiability

The last class (k = K) serves as the reference category.

Implementation

// Multinomial logistic regression for multi-class classification
// Note: This is a simplified version - full multinomial requires more complex implementation
fn multiclass_classification_demo() {
    println!("=== Multi-class Classification (Conceptual) ===\n");

    let (features, labels) = generate_multiclass_data(150, 3, 1337);

    println!("๐Ÿ“Š Generated {} data points", features.len());
    println!("   - {} classes", 3);
    println!("   - Features: {} dimensions", features[0].len());

    // Count class distribution
    let mut class_counts = [0; 3];
    for &label in &labels {
        class_counts[label] += 1;
    }

    for (class_id, count) in class_counts.iter().enumerate() {
        println!(
            "   - Class {}: {} samples ({:.1}%)",
            class_id,
            count,
            100.0 * *count as f64 / labels.len() as f64
        );
    }

    println!("\n๐Ÿ’ก Multinomial Classification Concepts:");
    println!("   - Uses K-1 sets of coefficients (reference category approach)");
    println!("   - Each coefficient set models log(P(class_k) / P(class_reference))");
    println!("   - Probabilities sum to 1 via softmax transformation");
    println!("   - More complex to implement but follows same Bayesian principles");

    // For now, demonstrate the concept with binary classification on each class
    println!("\n๐Ÿ”ฌ One-vs-Rest Classification (simplified approach):");

    for target_class in 0..3 {
        // Convert to binary problem: target_class vs. all others
        let binary_labels: Vec<bool> = labels.iter().map(|&label| label == target_class).collect();

        let positive_cases = binary_labels.iter().filter(|&&x| x).count();

        println!("\n   Class {} vs Rest:", target_class);
        println!(
            "   - Positive cases: {} / {}",
            positive_cases,
            binary_labels.len()
        );

        // Clone data for each iteration to avoid move issues
        let features_copy = features.clone();
        let model_fn =
            move || logistic_regression_model(features_copy.clone(), binary_labels.clone());
        let mut rng = StdRng::seed_from_u64(1000 + target_class as u64);

        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 60);
        let valid_samples = samples.len();

        if valid_samples > 0 {
            println!("   - MCMC: {} samples obtained", valid_samples);
        }
    }

    println!("\n๐Ÿ’ญ Note: Full multinomial logistic regression requires implementing");
    println!("   the softmax link function and careful handling of identifiability constraints.");
    println!();
}

Hierarchical Classification

When your data has group structure (e.g., students within schools, patients within hospitals), hierarchical models can improve predictions by sharing information across groups.

Mathematical Model

yᵢⱼ ~ Bernoulli(pᵢⱼ),   logit(pᵢⱼ) = αⱼ + β · xᵢⱼ,   αⱼ ~ Normal(α₀, σ²_group)

Where:

  • i indexes individuals, j indexes groups
  • αⱼ are group-specific intercepts
  • α₀ and σ_group control how much groups can vary

Implementation

// Hierarchical logistic regression with group-level effects
fn hierarchical_classification_model(
    features: Vec<Vec<f64>>,
    labels: Vec<bool>,
    groups: Vec<usize>,
) -> Model<(f64, f64, Vec<f64>)> {
    let n_groups = groups.iter().max().unwrap_or(&0) + 1;

    prob! {
        // Global parameters
        let global_intercept <- sample(addr!("global_intercept"), fugue::Normal::new(0.0, 2.0).unwrap());
        let slope <- sample(addr!("slope"), fugue::Normal::new(0.0, 2.0).unwrap());

        // Group-level variance
        let group_sigma <- sample(addr!("group_sigma"), Gamma::new(1.0, 1.0).unwrap());

        // Group-specific intercepts using plate notation
        let group_intercepts <- plate!(g in 0..n_groups => {
            sample(addr!("group_intercept", g), fugue::Normal::new(global_intercept, group_sigma).unwrap())
        });

        // Clone group_intercepts for use in closure
        let group_intercepts_for_obs = group_intercepts.clone();
        let _observations <- plate!(data in features.iter()
            .map(|f| f[1]) // Extract the single feature (after intercept)
            .zip(labels.iter())
            .zip(groups.iter())
            .enumerate() => {
            let (obs_idx, ((x_val, &y), &group_id)) = data;
            let linear_pred = group_intercepts_for_obs[group_id] + slope * x_val;
            let prob = 1.0 / (1.0 + { -linear_pred }.exp());
            let bounded_prob = prob.clamp(1e-10, 1.0 - 1e-10);

            observe(addr!("obs", obs_idx), Bernoulli::new(bounded_prob).unwrap(), y)
        });

        pure((global_intercept, slope, group_intercepts))
    }
}

fn hierarchical_classification_demo() {
    println!("=== Hierarchical Classification ===\n");

    let (features, labels, groups) = generate_hierarchical_data(4, 25, 5678);
    let n_groups = groups.iter().max().unwrap() + 1;

    println!("๐Ÿ“Š Generated hierarchical data:");
    println!(
        "   - {} groups with {} observations each",
        n_groups,
        features.len() / n_groups
    );
    println!("   - Total: {} data points", features.len());

    // Show group-wise statistics
    for group_id in 0..n_groups {
        let group_labels: Vec<bool> = groups
            .iter()
            .zip(labels.iter())
            .filter_map(|(&g, &y)| if g == group_id { Some(y) } else { None })
            .collect();

        let positive_rate =
            group_labels.iter().filter(|&&x| x).count() as f64 / group_labels.len() as f64;
        println!(
            "   - Group {}: {:.1}% positive cases",
            group_id,
            positive_rate * 100.0
        );
    }

    println!("\n๐Ÿ”ฌ Running hierarchical MCMC...");
    let model_fn =
        move || hierarchical_classification_model(features.clone(), labels.clone(), groups.clone());
    let mut rng = StdRng::seed_from_u64(9999);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 600, 150);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… Hierarchical MCMC completed with {} valid samples",
            valid_samples.len()
        );

        // Extract global parameters
        let global_intercepts: Vec<f64> =
            valid_samples.iter().map(|(params, _)| params.0).collect();
        let slopes: Vec<f64> = valid_samples.iter().map(|(params, _)| params.1).collect();

        let mean_global_int =
            global_intercepts.iter().sum::<f64>() / global_intercepts.len() as f64;
        let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;

        println!("\n๐Ÿ“ˆ Global Parameter Estimates:");
        println!("   - Global intercept: {:.3} (true: ~0.0)", mean_global_int);
        println!("   - Slope: {:.3} (true: 1.5)", mean_slope);

        // Extract group-specific intercepts
        println!("\n๐Ÿ˜๏ธ  Group-Specific Intercepts:");
        for group_id in 0..n_groups {
            let group_intercepts: Vec<f64> = valid_samples
                .iter()
                .map(|(params, _)| params.2[group_id])
                .collect();

            let mean_group_int =
                group_intercepts.iter().sum::<f64>() / group_intercepts.len() as f64;
            println!("   - Group {}: {:.3}", group_id, mean_group_int);
        }

        println!("\n๐Ÿ’ก Hierarchical Benefits:");
        println!("   - Groups share information through global parameters");
        println!("   - Individual groups can have their own intercepts");
        println!("   - Better predictions for groups with less data");
        println!("   - Automatic regularization through group-level priors");
    } else {
        println!("โŒ No valid hierarchical samples obtained");
    }

    println!();
}

Model Comparison and Selection

Bayesian methods provide principled approaches to comparing models:

Deviance Information Criterion (DIC)

DIC balances model fit against complexity:

DIC = D̄ + p_D

Where D̄ is the posterior mean deviance and p_D is the effective number of parameters.

Widely Applicable Information Criterion (WAIC)

WAIC is a more robust alternative:

WAIC = -2 (lppd - p_WAIC)

where lppd is the log pointwise predictive density and p_WAIC penalizes the variance of the pointwise log-likelihood across posterior draws.

Implementation

// Simple model comparison using log-likelihood
fn model_comparison_demo() {
    println!("=== Model Comparison ===\n");

    let (features, labels) = generate_classification_data(80, 2021);
    let _features_ref = &features;
    let _labels_ref = &labels;

    println!("๐Ÿ“Š Comparing different logistic regression models:");
    println!("   - Model 1: Intercept only");
    println!("   - Model 2: Intercept + Feature 1");
    println!("   - Model 3: Full model (Intercept + Feature 1 + Feature 2)");

    struct ModelResult {
        name: String,
        n_params: usize,
        log_likelihood: f64,
        samples: usize,
    }

    let mut results = Vec::new();

    // Model 1: Intercept only
    {
        let intercept_features: Vec<Vec<f64>> = features
            .iter()
            .map(|f| vec![f[0]]) // Just intercept
            .collect();
        let labels_clone = labels.clone();

        let model_fn =
            move || logistic_regression_model(intercept_features.clone(), labels_clone.clone());
        let mut rng = StdRng::seed_from_u64(1111);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 80);

        let valid_samples: Vec<_> = samples
            .iter()
            .filter(|(_, trace)| trace.total_log_weight().is_finite())
            .collect();

        if !valid_samples.is_empty() {
            let avg_log_lik = valid_samples
                .iter()
                .map(|(_, trace)| trace.total_log_weight())
                .sum::<f64>()
                / valid_samples.len() as f64;

            results.push(ModelResult {
                name: "Intercept only".to_string(),
                n_params: 1,
                log_likelihood: avg_log_lik,
                samples: valid_samples.len(),
            });
        }
    }

    // Model 2: Intercept + Feature 1
    {
        let reduced_features: Vec<Vec<f64>> = features
            .iter()
            .map(|f| vec![f[0], f[1]]) // Intercept + first feature
            .collect();
        let labels_clone = labels.clone();

        let model_fn =
            move || logistic_regression_model(reduced_features.clone(), labels_clone.clone());
        let mut rng = StdRng::seed_from_u64(2222);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 80);

        let valid_samples: Vec<_> = samples
            .iter()
            .filter(|(_, trace)| trace.total_log_weight().is_finite())
            .collect();

        if !valid_samples.is_empty() {
            let avg_log_lik = valid_samples
                .iter()
                .map(|(_, trace)| trace.total_log_weight())
                .sum::<f64>()
                / valid_samples.len() as f64;

            results.push(ModelResult {
                name: "Intercept + Feature 1".to_string(),
                n_params: 2,
                log_likelihood: avg_log_lik,
                samples: valid_samples.len(),
            });
        }
    }

    // Model 3: Full model
    {
        let labels_clone = labels.clone();
        let model_fn = move || logistic_regression_model(features.clone(), labels_clone.clone());
        let mut rng = StdRng::seed_from_u64(3333);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 80);

        let valid_samples: Vec<_> = samples
            .iter()
            .filter(|(_, trace)| trace.total_log_weight().is_finite())
            .collect();

        if !valid_samples.is_empty() {
            let avg_log_lik = valid_samples
                .iter()
                .map(|(_, trace)| trace.total_log_weight())
                .sum::<f64>()
                / valid_samples.len() as f64;

            results.push(ModelResult {
                name: "Full model".to_string(),
                n_params: 3,
                log_likelihood: avg_log_lik,
                samples: valid_samples.len(),
            });
        }
    }

    if !results.is_empty() {
        println!("\n๐Ÿ† Model Comparison Results:");
        println!("   Model                    | Params | Log-Likelihood | Samples");
        println!("   -------------------------|--------|----------------|--------");

        for result in &results {
            println!(
                "   {:24} | {:6} | {:14.2} | {:7}",
                result.name, result.n_params, result.log_likelihood, result.samples
            );
        }

        // Find best model
        if let Some(best) = results
            .iter()
            .max_by(|a, b| a.log_likelihood.partial_cmp(&b.log_likelihood).unwrap())
        {
            println!("\n๐Ÿฅ‡ Best Model: {} (highest log-likelihood)", best.name);
        }

        println!("\n๐Ÿ’ก Model Selection Notes:");
        println!("   - Higher log-likelihood indicates better fit to data");
        println!("   - In practice, use information criteria (AIC, BIC, WAIC)");
        println!("   - These account for model complexity to prevent overfitting");
        println!("   - Cross-validation provides robust model comparison");
    } else {
        println!("โŒ Model comparison failed - no valid samples obtained");
    }

    println!();
}
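
The comparison above ranks models by raw average log-likelihood, which ignores complexity. Below is a minimal WAIC sketch; it assumes you have collected per-observation log-likelihoods for each posterior draw (`log_lik[s][i]` = log p(yᵢ | θₛ)), which requires bookkeeping beyond the total log-weight stored in a trace.

// WAIC = -2 * (lppd - p_waic); lower values indicate better expected predictive fit.
fn waic(log_lik: &[Vec<f64>]) -> f64 {
    let s = log_lik.len() as f64;
    let n = log_lik[0].len();
    let (mut lppd, mut p_waic) = (0.0, 0.0);
    for i in 0..n {
        let lls: Vec<f64> = log_lik.iter().map(|row| row[i]).collect();
        // log pointwise predictive density via a stable log-sum-exp
        let max = lls.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        lppd += max + (lls.iter().map(|l| (l - max).exp()).sum::<f64>() / s).ln();
        // penalty: variance of the pointwise log-likelihood across draws
        let mean = lls.iter().sum::<f64>() / s;
        p_waic += lls.iter().map(|l| (l - mean).powi(2)).sum::<f64>() / (s - 1.0);
    }
    -2.0 * (lppd - p_waic)
}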

Practical Considerations

Feature Engineering

Effective classification often requires thoughtful feature engineering:

let x1 = 0.5; let x2 = 0.8; let category = "A";
// Polynomial features
let x1_squared = x1 * x1;
let x1_x2_interaction = x1 * x2;

// Categorical encoding (one-hot)
let is_category_a = if category == "A" { 1.0 } else { 0.0 };

Handling Class Imbalance

For imbalanced datasets, consider:

  • Weighted priors: Give more weight to rare classes
  • Threshold tuning: Optimize classification thresholds (see the sketch after this list)
  • Stratified sampling: Ensure balanced training data
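
For example, threshold tuning can be done on held-out data once posterior-mean predicted probabilities are available; a minimal sketch:

// Choose the decision threshold that maximizes F1 on held-out data,
// given posterior-mean predicted probabilities and true labels.
fn best_threshold(probs: &[f64], labels: &[bool]) -> f64 {
    let mut best = (0.5, 0.0);
    for t in (1..100).map(|k| k as f64 / 100.0) {
        let (mut tp, mut fp, mut fn_) = (0.0, 0.0, 0.0);
        for (&p, &y) in probs.iter().zip(labels.iter()) {
            let pred = p > t;
            if pred && y { tp += 1.0; }
            if pred && !y { fp += 1.0; }
            if !pred && y { fn_ += 1.0; }
        }
        let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
        let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
        let f1 = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };
        if f1 > best.1 { best = (t, f1); }
    }
    best.0
}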

Computational Considerations

  • Start simple: Begin with basic logistic regression
  • Check convergence: Monitor R-hat and effective sample size
  • Scale features: Standardize continuous predictors
  • Use constraints: Let Fugue's constraint-aware MCMC handle bounded parameters

MCMC for Classification

Classification models can be challenging for MCMC due to:

  • Separation: Perfect classification can lead to infinite parameter estimates
  • Weak identification: Sparse data in some classes affects convergence
  • Constraint handling: Probabilities must sum to 1 in multinomial models

Use regularizing priors and check diagnostics carefully.

Performance Evaluation

Metrics for Binary Classification

let tp = 10.0; let tn = 20.0; let fp = 5.0; let fn_count = 3.0;
// Accuracy, Precision, Recall, F1-score
let accuracy = (tp + tn) / (tp + tn + fp + fn_count);
let precision = tp / (tp + fp);
let recall = tp / (tp + fn_count);
let f1 = 2.0 * precision * recall / (precision + recall);

Bayesian Evaluation

Unlike traditional ML, Bayesian methods naturally provide:

  • Credible intervals for all metrics (see the sketch after this list)
  • Prediction intervals for new observations
  • Model uncertainty via posterior model probabilities
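
For instance, a credible interval for accuracy falls out directly by scoring each posterior draw separately; a minimal sketch, assuming coefficient draws shaped as in the logistic regression demo:

// 95% credible interval for classification accuracy across posterior draws.
fn accuracy_interval(coef_samples: &[Vec<f64>], features: &[Vec<f64>], labels: &[bool]) -> (f64, f64) {
    let mut accs: Vec<f64> = coef_samples.iter().map(|coeffs| {
        let mut correct = 0usize;
        for (x, &y) in features.iter().zip(labels.iter()) {
            let eta: f64 = coeffs.iter().zip(x.iter()).map(|(c, xi)| c * xi).sum();
            let pred = 1.0 / (1.0 + (-eta).exp()) > 0.5;
            if pred == y { correct += 1; }
        }
        correct as f64 / labels.len() as f64
    }).collect();
    accs.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let lo = accs[(0.025 * (accs.len() - 1) as f64).round() as usize];
    let hi = accs[(0.975 * (accs.len() - 1) as f64).round() as usize];
    (lo, hi)
}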

Advanced Extensions

Ordinal Classification

For ordered categorical outcomes (e.g., ratings, severity levels):

use fugue::*;

// Ordinal logistic regression with proportional odds
fn ordinal_classification_model(
    features: Vec<Vec<f64>>,
    outcomes: Vec<usize>, // 0, 1, 2, ..., K-1
    n_categories: usize
) -> Model<(Vec<f64>, Vec<f64>)> {
    prob! {
        // Regression coefficients (shared across categories)
        let coefficients <- plate!(i in 0..features[0].len() => {
            sample(addr!("beta", i), fugue::Normal::new(0.0, 2.0).unwrap())
        });
        
        // Cutpoints (must be ordered)
        let mut cutpoints = Vec::new();
        let first_cut <- sample(addr!("cutpoint", 0), fugue::Normal::new(0.0, 5.0).unwrap());
        cutpoints.push(first_cut);
        
        for k in 1..(n_categories-1) {
            let delta <- sample(addr!("delta", k), Gamma::new(1.0, 1.0).unwrap());
            cutpoints.push(cutpoints[k-1] + delta);
        }
        
        // Likelihood using cumulative logits
        let _observations <- plate!(obs_idx in features.iter().zip(outcomes.iter()).enumerate() => {
            let (idx, (x_vec, &y)) = obs_idx;
            let mut linear_pred = 0.0;
            for (coef, &x_val) in coefficients.iter().zip(x_vec.iter()) {
                linear_pred += coef * x_val;
            }
            
            // Compute category probabilities
            let mut probs = Vec::new();
            for k in 0..n_categories {
                let prob = if k == 0 {
                    1.0 / (1.0 + (-(cutpoints[0] - linear_pred)).exp())
                } else if k == n_categories - 1 {
                    1.0 - (1.0 / (1.0 + (-(cutpoints[k-1] - linear_pred)).exp()))
                } else {
                    let p_le_k = 1.0 / (1.0 + (-(cutpoints[k] - linear_pred)).exp());
                    let p_le_k_minus_1 = 1.0 / (1.0 + (-(cutpoints[k-1] - linear_pred)).exp());
                    p_le_k - p_le_k_minus_1
                };
                probs.push(prob.max(1e-10).min(1.0 - 1e-10));
            }
            
            observe(addr!("y", idx), Categorical::new(probs).unwrap(), y)
        });
        
        pure((coefficients, cutpoints))
    }
}

Robust Classification

Handle outliers using heavy-tailed link functions:

use fugue::*;

// Robust logistic regression with t-distributed errors
fn robust_classification_model(
    features: Vec<Vec<f64>>,
    labels: Vec<bool>
) -> Model<(Vec<f64>, f64)> {
    prob! {
        // Coefficients
        let coefficients <- plate!(i in 0..features[0].len() => {
            sample(addr!("beta", i), fugue::Normal::new(0.0, 2.0).unwrap())
        });
        
        // Degrees of freedom for robustness
        let nu <- sample(addr!("nu"), Gamma::new(2.0, 0.1).unwrap());
        
        // Robust likelihood using latent variables
        let _observations <- plate!(obs_idx in features.iter().zip(labels.iter()).enumerate() => {
            let (idx, (x_vec, &y)) = obs_idx;
            
            // Linear predictor
            let mut eta = 0.0;
            for (coef, &x_val) in coefficients.iter().zip(x_vec.iter()) {
                eta += coef * x_val;
            }
            
            // Latent variable for robustness
            let z <- sample(addr!("z", idx), fugue::Normal::new(eta, 1.0).unwrap());
            
            // Robust transformation
            let p = 1.0 / (1.0 + (-z).exp());
            let bounded_p = p.max(1e-10).min(1.0 - 1e-10);
            
            observe(addr!("y", idx), Bernoulli::new(bounded_p).unwrap(), y)
        });
        
        pure((coefficients, nu))
    }
}

Production Considerations

Scalability

For large datasets, consider:

  1. Mini-batch MCMC: Process data in chunks for memory efficiency
  2. Variational Inference: Approximate posteriors for faster computation
  3. Sparse Models: Use regularization for high-dimensional feature spaces
  4. GPU Acceleration: Vectorized operations for matrix computations

Production Deployment

  • Monitor convergence: Set up automated R-hat checking
  • Prediction pipelines: Cache MCMC samples for fast inference (sketched after this list)
  • Model updating: Implement online learning for streaming data
  • A/B testing: Use Bayesian methods for experiment analysis
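
A minimal sketch of such a cached prediction pipeline: posterior coefficient draws are stored once, and each incoming feature vector is scored against all of them, so no MCMC runs at prediction time.

// Cache posterior draws once; reuse them for fast posterior-mean predictions.
struct LogisticPredictor {
    draws: Vec<Vec<f64>>, // one coefficient vector per posterior sample
}

impl LogisticPredictor {
    fn predict_proba(&self, x: &[f64]) -> f64 {
        self.draws.iter().map(|coeffs| {
            let eta: f64 = coeffs.iter().zip(x.iter()).map(|(c, xi)| c * xi).sum();
            1.0 / (1.0 + (-eta).exp())
        }).sum::<f64>() / self.draws.len() as f64
    }
}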

Model Diagnostics

Essential checks for classification models:

use fugue::inference::diagnostics::*;

fn classification_diagnostics(
    samples: &[Vec<f64>], 
    features: &[Vec<f64>], 
    labels: &[bool]
) {
    // Compute prediction accuracy
    let predictions: Vec<bool> = features.iter().enumerate().map(|(i, x_vec)| {
        let prob: f64 = samples.iter().map(|coeffs| {
            let linear_pred = coeffs.iter().zip(x_vec.iter())
                .map(|(coef, x)| coef * x).sum::<f64>();
            1.0 / (1.0 + (-linear_pred).exp())
        }).sum::<f64>() / samples.len() as f64;
        
        prob > 0.5
    }).collect();
    
    // Classification metrics
    let tp = predictions.iter().zip(labels.iter())
        .filter(|(&pred, &actual)| pred && actual).count();
    let tn = predictions.iter().zip(labels.iter())
        .filter(|(&pred, &actual)| !pred && !actual).count();
    let fp = predictions.iter().zip(labels.iter())
        .filter(|(&pred, &actual)| pred && !actual).count();
    let fn_ = predictions.iter().zip(labels.iter())
        .filter(|(&pred, &actual)| !pred && actual).count();
    
    let accuracy = (tp + tn) as f64 / labels.len() as f64;
    let precision = tp as f64 / (tp + fp) as f64;
    let recall = tp as f64 / (tp + fn_) as f64;
    
    println!("Classification Diagnostics:");
    println!("  Accuracy: {:.3}", accuracy);
    println!("  Precision: {:.3}", precision);
    println!("  Recall: {:.3}", recall);
    println!("  F1-Score: {:.3}", 2.0 * precision * recall / (precision + recall));
}

Running the Examples

To explore these classification techniques:

# Run the classification demonstrations
cargo run --example classification

# Run specific tests
cargo test --example classification

# Build documentation with examples
mdbook build docs/

Key Takeaways

Classification Mastery

  1. Bayesian Advantage: Natural uncertainty quantification through posterior distributions
  2. Model Flexibility: Handle binary, multi-class, ordinal, and hierarchical outcomes
  3. Robust Methods: Constraint-aware MCMC prevents numerical issues
  4. Principled Selection: Use information criteria and Bayes factors for model choice
  5. Production Ready: Scalable workflows with proper diagnostics and validation
  6. Real-World Applications: Flexible framework for diverse classification problems

Core Techniques:

  • โœ… Binary Classification with logistic regression and uncertainty
  • โœ… Multi-class Methods using multinomial and one-vs-rest approaches
  • โœ… Hierarchical Models for grouped and nested data structures
  • โœ… Model Comparison with information criteria and Bayes factors
  • โœ… Robust Extensions for outlier resistance and stability
  • โœ… Production Deployment with monitoring and scalable inference

Further Reading

  • Building Complex Models - Advanced modeling techniques
  • Optimizing Performance - Scalable inference strategies
  • Hierarchical Models - Advanced multilevel modeling
  • Mixture Models - Unsupervised classification and clustering
  • Time Series - Classification with temporal structure
  • Gelman et al. "Bayesian Data Analysis" - Comprehensive statistical reference
  • McElreath "Statistical Rethinking" - Modern Bayesian approach
  • Kruschke "Doing Bayesian Data Analysis" - Applied Bayesian methods

The combination of Fugue's type-safe probabilistic programming and constraint-aware MCMC makes Bayesian classification both theoretically principled and computationally practical. The natural uncertainty quantification provides insights that traditional point estimates cannot match.

Mixture Models

A comprehensive guide to Bayesian mixture modeling using Fugue. This tutorial demonstrates how to build, analyze, and extend mixture models for complex data structures, showcasing advanced probabilistic programming techniques for unsupervised learning and heterogeneous populations.

Learning Objectives

By the end of this tutorial, you will understand:

  • Gaussian Mixture Models: Foundation of mixture modeling for continuous data
  • Latent Variable Inference: MCMC techniques for unobserved cluster assignments
  • Model Selection: Choosing the optimal number of mixture components
  • Mixture of Experts: Supervised mixture models for complex decision boundaries
  • Infinite Mixtures: Dirichlet Process models for automatic component discovery
  • Temporal Mixtures: Hidden Markov Models and dynamic clustering
  • Advanced Diagnostics: Convergence assessment and cluster validation

The Mixture Modeling Framework

Mixture models assume that observed data arise from a mixture of underlying populations, each governed by its own distribution. This framework naturally handles heterogeneous data where simple single-distribution models fail.

graph TB
    A["Heterogeneous Data<br/>Multiple Populations"] --> B["Mixture Model<br/>โˆ‘ ฯ€โ‚– f(x|ฮธโ‚–)"]
    
    B --> C["Components"]
    C --> D["Component 1<br/>ฯ€โ‚, ฮธโ‚"]
    C --> E["Component 2<br/>ฯ€โ‚‚, ฮธโ‚‚"] 
    C --> F["Component K<br/>ฯ€โ‚–, ฮธโ‚–"]
    
    G["Latent Variables"] --> H["Cluster Assignments<br/>zแตข โˆˆ {1,...,K}"]
    G --> I["Mixing Weights<br/>ฯ€ = (ฯ€โ‚,...,ฯ€โ‚–)"]
    
    B --> J["Bayesian Inference"]
    J --> K["MCMC Sampling"]
    K --> L["Posterior Distributions"]
    K --> M["Cluster Predictions"]
    K --> N["Model Comparison"]
    
    style B fill:#ccffcc
    style L fill:#e1f5fe
    style M fill:#e1f5fe
    style N fill:#e1f5fe

Mathematical Foundation

Basic Mixture Model

For K components, the mixture density is:

p(xᵢ | π, θ) = Σₖ πₖ f(xᵢ | θₖ),   k = 1, ..., K

Where:

  • πₖ are mixing weights with πₖ ≥ 0 and Σₖ πₖ = 1
  • f(x | θₖ) is the component density for cluster k
  • θₖ are component-specific parameters

Latent Variable Formulation

Introduce latent cluster assignments zᵢ ∈ {1, ..., K}:

zᵢ ~ Categorical(π),   xᵢ | zᵢ = k ~ f(x | θₖ)

This data augmentation approach enables efficient MCMC inference.
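
Given parameter values, each point's latent assignment has a closed-form posterior, often called its responsibility; a minimal sketch for the two-component Gaussian case:

// Responsibility r_k = P(z_i = k | x_i, pi, mu, sigma) for a 2-component Gaussian mixture.
fn responsibility(x: f64, pi1: f64, mu: [f64; 2], sigma: [f64; 2]) -> [f64; 2] {
    let normal_pdf = |x: f64, m: f64, s: f64| {
        (-0.5 * ((x - m) / s).powi(2)).exp() / (s * (2.0 * std::f64::consts::PI).sqrt())
    };
    let w0 = pi1 * normal_pdf(x, mu[0], sigma[0]);
    let w1 = (1.0 - pi1) * normal_pdf(x, mu[1], sigma[1]);
    let total = w0 + w1;
    [w0 / total, w1 / total]
}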

Mixture Model Advantages

  • Flexibility: Model complex, multimodal distributions
  • Interpretability: Each component represents a meaningful subpopulation
  • Uncertainty: Natural clustering with prediction confidence
  • Extensibility: Easy to incorporate covariates and hierarchical structure

Gaussian Mixture Models

The most common mixture model uses Gaussian components, ideal for continuous data clustering.

Mathematical Model

For K Gaussian components:

xᵢ | zᵢ = k ~ Normal(μₖ, σₖ²),   zᵢ ~ Categorical(π₁, ..., π_K)

Priors:

π ~ Dirichlet(α₁, ..., α_K),   μₖ ~ Normal(μ₀, σ₀²),   σₖ ~ Gamma(a, b)

Implementation

// Simple 2-component Gaussian mixture model
fn gaussian_mixture_model(data: Vec<f64>) -> Model<(f64, f64, f64, f64, f64)> {
    prob! {
        // Mixing weight for first component
        let pi1 <- sample(addr!("pi1"), fugue::Beta::new(1.0, 1.0).unwrap());

        // Component 1 parameters
        let mu1 <- sample(addr!("mu1"), fugue::Normal::new(0.0, 5.0).unwrap());
        let sigma1 <- sample(addr!("sigma1"), Gamma::new(1.0, 1.0).unwrap());

        // Component 2 parameters
        let mu2 <- sample(addr!("mu2"), fugue::Normal::new(0.0, 5.0).unwrap());
        let sigma2 <- sample(addr!("sigma2"), Gamma::new(1.0, 1.0).unwrap());

        // Observations
        let _observations <- plate!(i in 0..data.len() => {
            // Ensure valid probabilities
            let p1 = pi1.clamp(0.001, 0.999); // Clamp to valid range
            let weights = vec![p1, 1.0 - p1];
            let x = data[i];

            sample(addr!("z", i), Categorical::new(weights).unwrap())
                .bind(move |z_i| {
                    // Explicitly handle only 2 components
                    let (mu_i, sigma_i) = if z_i == 0 {
                        (mu1, sigma1)
                    } else {
                        (mu2, sigma2)
                    };
                    observe(addr!("x", i), fugue::Normal::new(mu_i, sigma_i).unwrap(), x)
                })
        });

        pure((pi1, mu1, sigma1, mu2, sigma2))
    }
}

fn gaussian_mixture_demo() {
    println!("=== Gaussian Mixture Model ===\n");

    // Generate synthetic mixture data: 2 components
    let true_components = vec![
        (0.6, -1.5, 0.8), // 60% weight, mean=-1.5, std=0.8
        (0.4, 2.0, 1.2),  // 40% weight, mean=2.0, std=1.2
    ];
    let (data, true_labels) = generate_mixture_data(80, &true_components, 42);

    println!(
        "๐Ÿ“Š Generated {} data points from {} true components",
        data.len(),
        true_components.len()
    );
    for (i, (weight, mu, sigma)) in true_components.iter().enumerate() {
        println!(
            "   - Component {}: ฯ€={:.1}, ฮผ={:.1}, ฯƒ={:.1}",
            i + 1,
            weight,
            mu,
            sigma
        );
    }

    let n_true_labels: Vec<usize> = (0..true_components.len())
        .map(|k| true_labels.iter().filter(|&&label| label == k).count())
        .collect();

    for (k, count) in n_true_labels.iter().enumerate() {
        println!(
            "   - True cluster {}: {} observations ({:.1}%)",
            k + 1,
            count,
            100.0 * *count as f64 / data.len() as f64
        );
    }

    // Fit mixture model
    println!("\n๐Ÿ”ฌ Fitting 2-component Gaussian mixture model...");
    let model_fn = move || gaussian_mixture_model(data.clone());
    let mut rng = StdRng::seed_from_u64(123);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 600, 150);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… MCMC completed with {} valid samples",
            valid_samples.len()
        );

        // Extract parameter estimates
        println!("\n๐Ÿ“ˆ Estimated Parameters:");

        let pi1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.0).collect();
        let mu1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.1).collect();
        let sigma1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.2).collect();
        let mu2_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.3).collect();
        let sigma2_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.4).collect();

        let mean_pi1 = pi1_samples.iter().sum::<f64>() / pi1_samples.len() as f64;
        let mean_mu1 = mu1_samples.iter().sum::<f64>() / mu1_samples.len() as f64;
        let mean_sigma1 = sigma1_samples.iter().sum::<f64>() / sigma1_samples.len() as f64;
        let mean_mu2 = mu2_samples.iter().sum::<f64>() / mu2_samples.len() as f64;
        let mean_sigma2 = sigma2_samples.iter().sum::<f64>() / sigma2_samples.len() as f64;

        println!(
            "   - Component 1: ฯ€ฬ‚={:.2}, ฮผฬ‚={:.1}, ฯƒฬ‚={:.1}",
            mean_pi1, mean_mu1, mean_sigma1
        );
        println!(
            "   - Component 2: ฯ€ฬ‚={:.2}, ฮผฬ‚={:.1}, ฯƒฬ‚={:.1}",
            1.0 - mean_pi1,
            mean_mu2,
            mean_sigma2
        );

        println!("\n๐ŸŽฏ Parameter Recovery:");
        let (true_w1, true_mu1, true_sigma1) = true_components[0];
        let (true_w2, true_mu2, true_sigma2) = true_components[1];
        println!("   - Component 1: ฯ€ true={:.1} est={:.2}, ฮผ true={:.1} est={:.1}, ฯƒ true={:.1} est={:.1}",
                true_w1, mean_pi1, true_mu1, mean_mu1, true_sigma1, mean_sigma1);
        println!("   - Component 2: ฯ€ true={:.1} est={:.2}, ฮผ true={:.1} est={:.1}, ฯƒ true={:.1} est={:.1}",
                true_w2, 1.0 - mean_pi1, true_mu2, mean_mu2, true_sigma2, mean_sigma2);
    } else {
        println!("โŒ No valid MCMC samples obtained");
    }

    println!();
}

Key Features

  • Automatic clustering: Soft cluster assignments via posterior probabilities
  • Uncertainty quantification: Credible intervals for all parameters
  • Model comparison: Bayesian model selection for optimal

Label Switching in Mixtures

Mixture components are not identifiable due to label switching - permuting component labels gives the same likelihood. This causes:

  • Multimodal posteriors: Multiple equivalent parameter configurations
  • MCMC convergence issues: Chains can jump between label permutations

Solutions: Use informative priors, post-process with label matching, or employ specialized algorithms like the allocation sampler.
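
A minimal post-processing sketch for the two-component model above: order components by their means so every posterior draw uses a consistent labeling. This assumes draws shaped like the model's return tuple (π₁, μ₁, σ₁, μ₂, σ₂).

// Relabel each draw of (pi1, mu1, sigma1, mu2, sigma2) so that mu1 <= mu2.
fn relabel(draws: &mut [(f64, f64, f64, f64, f64)]) {
    for d in draws.iter_mut() {
        if d.1 > d.3 {
            *d = (1.0 - d.0, d.3, d.4, d.1, d.2);
        }
    }
}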

Multivariate Mixtures

Extend to multivariate data with full covariance structure.

Mathematical Model

For d-dimensional data:

xᵢ | zᵢ = k ~ Normal_d(μₖ, Σₖ)

Normal-Inverse-Wishart Priors:

μₖ | Σₖ ~ Normal(μ₀, Σₖ / κ₀),   Σₖ ~ Inverse-Wishart(ν₀, Ψ₀)

Implementation

// Simple 2-component multivariate Gaussian mixture (2D)
#[allow(clippy::type_complexity)] // Complex tuple needed for demonstration
fn multivariate_mixture_model(
    data: Vec<Vec<f64>>,
) -> Model<(f64, f64, f64, f64, f64, f64, f64, f64)> {
    prob! {
        // Mixing weight
        let pi1 <- sample(addr!("pi1"), fugue::Beta::new(1.0, 1.0).unwrap());

        // Component 1 parameters (2D means and diagonal covariance)
        let mu1_0 <- sample(addr!("mu1_0"), fugue::Normal::new(0.0, 5.0).unwrap());
        let mu1_1 <- sample(addr!("mu1_1"), fugue::Normal::new(0.0, 5.0).unwrap());
        let sigma1_0 <- sample(addr!("sigma1_0"), Gamma::new(1.0, 1.0).unwrap());
        let sigma1_1 <- sample(addr!("sigma1_1"), Gamma::new(1.0, 1.0).unwrap());

        // Component 2 parameters
        let mu2_0 <- sample(addr!("mu2_0"), fugue::Normal::new(0.0, 5.0).unwrap());
        let mu2_1 <- sample(addr!("mu2_1"), fugue::Normal::new(0.0, 5.0).unwrap());
        let sigma2_0 <- sample(addr!("sigma2_0"), Gamma::new(1.0, 1.0).unwrap());
        let sigma2_1 <- sample(addr!("sigma2_1"), Gamma::new(1.0, 1.0).unwrap());

        // Observations (diagonal covariance assumption)
        let _observations <- plate!(i in 0..data.len() => {
            let p1 = pi1.clamp(0.001, 0.999);
            let weights = vec![p1, 1.0 - p1];
            let x0 = data[i][0];
            let x1 = data[i][1];

            sample(addr!("z", i), Categorical::new(weights).unwrap())
                .bind(move |z_i| {
                    let (mu_0, mu_1, sigma_0, sigma_1) = if z_i == 0 {
                        (mu1_0, mu1_1, sigma1_0, sigma1_1)
                    } else {
                        (mu2_0, mu2_1, sigma2_0, sigma2_1)
                    };

                    // Independent dimensions (diagonal covariance)
                    observe(addr!("x0", i), fugue::Normal::new(mu_0, sigma_0).unwrap(), x0)
                        .bind(move |_| {
                            observe(addr!("x1", i), fugue::Normal::new(mu_1, sigma_1).unwrap(), x1)
                        })
                })
        });

        pure((pi1, mu1_0, mu1_1, sigma1_0, sigma1_1, mu2_0, mu2_1, sigma2_0))
    }
}

fn multivariate_mixture_demo() {
    println!("=== Multivariate Gaussian Mixture Model ===\n");

    // Generate 2D mixture data
    let true_components = vec![(0.6, vec![-1.0, -1.0], 0.5), (0.4, vec![2.0, 1.5], 0.7)];
    let (data, true_labels) = generate_multivariate_mixture_data(60, &true_components, 456);

    println!(
        "๐Ÿ“Š Generated {} 2D data points from {} components",
        data.len(),
        true_components.len()
    );
    for (i, (weight, ref mu_vec, sigma)) in true_components.iter().enumerate() {
        println!(
            "   - Component {}: ฯ€={:.1}, ฮผ=[{:.1}, {:.1}], ฯƒ={:.1}",
            i + 1,
            weight,
            mu_vec[0],
            mu_vec[1],
            sigma
        );
    }

    let n_true_labels: Vec<usize> = (0..true_components.len())
        .map(|k| true_labels.iter().filter(|&&label| label == k).count())
        .collect();

    for (k, count) in n_true_labels.iter().enumerate() {
        println!(
            "   - True cluster {}: {} observations ({:.1}%)",
            k + 1,
            count,
            100.0 * *count as f64 / data.len() as f64
        );
    }

    println!("\n๐Ÿ”ฌ Fitting 2D mixture model with K=2...");
    let model_fn = move || multivariate_mixture_model(data.clone());
    let mut rng = StdRng::seed_from_u64(789);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 500, 100);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… MCMC completed with {} valid samples",
            valid_samples.len()
        );

        // Extract parameter estimates
        let pi1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.0).collect();
        let mu1_0_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.1).collect();
        let mu1_1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.2).collect();
        let mu2_0_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.5).collect();
        let mu2_1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.6).collect();

        let mean_pi1 = pi1_samples.iter().sum::<f64>() / pi1_samples.len() as f64;
        let mean_mu1_0 = mu1_0_samples.iter().sum::<f64>() / mu1_0_samples.len() as f64;
        let mean_mu1_1 = mu1_1_samples.iter().sum::<f64>() / mu1_1_samples.len() as f64;
        let mean_mu2_0 = mu2_0_samples.iter().sum::<f64>() / mu2_0_samples.len() as f64;
        let mean_mu2_1 = mu2_1_samples.iter().sum::<f64>() / mu2_1_samples.len() as f64;

        println!("\n๐Ÿ“ˆ Estimated 2D Mixture Components:");
        println!(
            "   - Component 1: ฯ€ฬ‚={:.2}, ฮผฬ‚=[{:.1}, {:.1}]",
            mean_pi1, mean_mu1_0, mean_mu1_1
        );
        println!(
            "   - Component 2: ฯ€ฬ‚={:.2}, ฮผฬ‚=[{:.1}, {:.1}]",
            1.0 - mean_pi1,
            mean_mu2_0,
            mean_mu2_1
        );

        println!("\n๐Ÿ’ก Multivariate mixture models handle correlated features and complex cluster shapes!");
    } else {
        println!("โŒ No valid MCMC samples obtained");
    }

    println!();
}

// Generate multivariate mixture data
fn generate_multivariate_mixture_data(
    n: usize,
    components: &[(f64, Vec<f64>, f64)], // (weight, mean_vec, sigma)
    seed: u64,
) -> (Vec<Vec<f64>>, Vec<usize>) {
    let mut rng = StdRng::seed_from_u64(seed);
    let mut data = Vec::new();
    let mut true_labels = Vec::new();

    let weights: Vec<f64> = components.iter().map(|(w, _, _)| *w).collect();
    let cumulative_weights: Vec<f64> = weights
        .iter()
        .scan(0.0, |acc, &w| {
            *acc += w;
            Some(*acc)
        })
        .collect();

    for _ in 0..n {
        let u: f64 = rng.gen();
        let component = cumulative_weights
            .iter()
            .position(|&cw| u <= cw)
            .unwrap_or(components.len() - 1);
        let (_, ref mu_vec, sigma) = components[component];

        let mut x_vec = Vec::new();
        for &mu in mu_vec {
            let noise: f64 = StandardNormal.sample(&mut rng);
            x_vec.push(mu + sigma * noise);
        }

        data.push(x_vec);
        true_labels.push(component);
    }

    (data, true_labels)
}

Advanced Features

  • Full covariance: Captures correlation structure within clusters
  • Regularization: Inverse-Wishart priors prevent overfitting
  • Dimensionality: Scales to high-dimensional feature spaces

Mixture of Experts

Supervised mixture models where component probabilities depend on covariates.

Mathematical Model

Gating Network:

P(zᵢ = k | xᵢ) = exp(wₖᵀxᵢ) / Σⱼ exp(wⱼᵀxᵢ)

Expert Networks:

yᵢ | zᵢ = k, xᵢ ~ Normal(β₀ₖ + β₁ₖ xᵢ, σₖ²)

Implementation

// Simple mixture of experts with 2 experts
fn mixture_of_experts_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
) -> Model<(f64, f64, f64, f64, f64, f64)> {
    prob! {
        // Expert 1 parameters (for x < 0)
        let intercept1 <- sample(addr!("intercept1"), fugue::Normal::new(0.0, 2.0).unwrap());
        let slope1 <- sample(addr!("slope1"), fugue::Normal::new(0.0, 2.0).unwrap());
        let sigma1 <- sample(addr!("sigma1"), Gamma::new(1.0, 1.0).unwrap());

        // Expert 2 parameters (for x >= 0)
        let intercept2 <- sample(addr!("intercept2"), fugue::Normal::new(0.0, 2.0).unwrap());
        let slope2 <- sample(addr!("slope2"), fugue::Normal::new(0.0, 2.0).unwrap());
        let sigma2 <- sample(addr!("sigma2"), Gamma::new(1.0, 1.0).unwrap());

        // Observations with simple binary gating
        let _observations <- plate!(i in 0..x_data.len() => {
            let x = x_data[i];
            let y = y_data[i];

            if x < 0.0 {
                // Use expert 1
                let mean_y = intercept1 + slope1 * x;
                observe(addr!("y", i), fugue::Normal::new(mean_y, sigma1).unwrap(), y)
            } else {
                // Use expert 2
                let mean_y = intercept2 + slope2 * x;
                observe(addr!("y", i), fugue::Normal::new(mean_y, sigma2).unwrap(), y)
            }
        });

        pure((intercept1, slope1, sigma1, intercept2, slope2, sigma2))
    }
}

fn mixture_of_experts_demo() {
    println!("=== Mixture of Experts ===\n");

    let (x_data, y_data) = generate_moe_data(60, 321);

    println!(
        "๐Ÿ“Š Generated {} (x,y) points with region-specific relationships",
        x_data.len()
    );
    println!("   - Left region (x < 0): Linear relationship");
    println!("   - Right region (x โ‰ฅ 0): Quadratic relationship");

    println!("\n๐Ÿ”ฌ Fitting mixture of experts with 2 experts...");
    let model_fn = move || mixture_of_experts_model(x_data.clone(), y_data.clone());
    let mut rng = StdRng::seed_from_u64(654);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 500, 100);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… MCMC completed with {} valid samples",
            valid_samples.len()
        );

        println!("\n๐Ÿ“ˆ Expert Network Parameters:");

        let intercept1_samples: Vec<f64> =
            valid_samples.iter().map(|(params, _)| params.0).collect();
        let slope1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.1).collect();
        let sigma1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.2).collect();

        let intercept2_samples: Vec<f64> =
            valid_samples.iter().map(|(params, _)| params.3).collect();
        let slope2_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.4).collect();
        let sigma2_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.5).collect();

        let mean_intercept1 =
            intercept1_samples.iter().sum::<f64>() / intercept1_samples.len() as f64;
        let mean_slope1 = slope1_samples.iter().sum::<f64>() / slope1_samples.len() as f64;
        let mean_sigma1 = sigma1_samples.iter().sum::<f64>() / sigma1_samples.len() as f64;

        let mean_intercept2 =
            intercept2_samples.iter().sum::<f64>() / intercept2_samples.len() as f64;
        let mean_slope2 = slope2_samples.iter().sum::<f64>() / slope2_samples.len() as f64;
        let mean_sigma2 = sigma2_samples.iter().sum::<f64>() / sigma2_samples.len() as f64;

        println!(
            "   - Expert 1 [Left (x < 0)]: intercept={:.2}, slope={:.2}, ฯƒ={:.2}",
            mean_intercept1, mean_slope1, mean_sigma1
        );
        println!(
            "   - Expert 2 [Right (x โ‰ฅ 0)]: intercept={:.2}, slope={:.2}, ฯƒ={:.2}",
            mean_intercept2, mean_slope2, mean_sigma2
        );

        println!(
            "\n๐Ÿ’ก Mixture of Experts captures different relationships in different input regions"
        );
    } else {
        println!("โŒ No valid MCMC samples obtained");
    }

    println!();
}

Applications

  • Complex regression: Different relationships in different regions
  • Classification boundaries: Non-linear decision boundaries
  • Expert systems: Specialized models for different domains
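
The implementation above hard-codes the gate at x = 0. A softer variant lets a logistic gating function choose the expert probabilistically. The sketch below is illustrative (the model name and gating parameters w0, w1 are not part of the example code) and assumes the same prob!/plate!/bind primitives used throughout this chapter:

use fugue::*;

// Hypothetical sketch: two linear experts selected by a logistic gate.
// p(expert 2 | x) = sigmoid(w0 + w1 * x); each observation picks an expert.
fn soft_gated_experts_model(x_data: Vec<f64>, y_data: Vec<f64>) -> Model<(f64, f64)> {
    prob! {
        // Gating parameters
        let w0 <- sample(addr!("w0"), fugue::Normal::new(0.0, 2.0).unwrap());
        let w1 <- sample(addr!("w1"), fugue::Normal::new(0.0, 2.0).unwrap());

        // Expert parameters (intercepts only, to keep the sketch short)
        let mu1 <- sample(addr!("mu1"), fugue::Normal::new(0.0, 2.0).unwrap());
        let mu2 <- sample(addr!("mu2"), fugue::Normal::new(0.0, 2.0).unwrap());
        let sigma <- sample(addr!("sigma"), Gamma::new(1.0, 1.0).unwrap());

        let _obs <- plate!(i in 0..x_data.len() => {
            let x = x_data[i];
            let y = y_data[i];
            // Logistic gate: probability of using expert 2 grows with x
            let p2 = 1.0 / (1.0 + (-(w0 + w1 * x)).exp());
            sample(addr!("gate", i), Bernoulli::new(p2.clamp(1e-6, 1.0 - 1e-6)).unwrap())
                .bind(move |use_expert2| {
                    let mean_y = if use_expert2 { mu2 } else { mu1 };
                    observe(addr!("y", i), fugue::Normal::new(mean_y, sigma).unwrap(), y)
                })
        });

        pure((w0, w1))
    }
}

With a learned gate, the boundary between experts is inferred from the data instead of being fixed in advance.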

Infinite Mixtures: Dirichlet Process

When the number of components is unknown, use Dirichlet Process mixtures for automatic model selection.

Mathematical Framework

Dirichlet Process: \[ G \sim \text{DP}(\alpha, G_0), \qquad \theta_i \mid G \sim G, \qquad x_i \mid \theta_i \sim F(\theta_i) \]

Chinese Restaurant Process: Component assignments follow: \[ P(z_i = k \mid z_{1:i-1}) = \frac{n_k}{i - 1 + \alpha}, \qquad P(z_i = k_{\text{new}} \mid z_{1:i-1}) = \frac{\alpha}{i - 1 + \alpha} \]

Where \( n_k \) is the number of observations already assigned to component \( k \).

Implementation

// Simplified Dirichlet Process with truncated stick-breaking
#[allow(clippy::type_complexity)] // Complex tuple needed for demonstration
fn dirichlet_process_mixture_model(
    data: Vec<f64>,
) -> Model<(f64, f64, f64, f64, f64, f64, f64, usize)> {
    prob! {
        // Stick-breaking for 3 components (truncated)
        let v1 <- sample(addr!("v1"), fugue::Beta::new(1.0, 1.0).unwrap());
        let v2 <- sample(addr!("v2"), fugue::Beta::new(1.0, 1.0).unwrap());

        // Convert to weights (clamp to avoid negative probabilities during MCMC)
        let v1_safe = v1.clamp(0.001, 0.999);
        let v2_safe = v2.clamp(0.001, 0.999);

        let w1 = v1_safe;
        let w2 = (1.0 - v1_safe) * v2_safe;
        let w3 = (1.0 - v1_safe) * (1.0 - v2_safe);

        // Component parameters
        let mu1 <- sample(addr!("mu1"), fugue::Normal::new(0.0, 5.0).unwrap());
        let mu2 <- sample(addr!("mu2"), fugue::Normal::new(0.0, 5.0).unwrap());
        let mu3 <- sample(addr!("mu3"), fugue::Normal::new(0.0, 5.0).unwrap());

        let sigma1 <- sample(addr!("sigma1"), Gamma::new(1.0, 1.0).unwrap());
        let sigma2 <- sample(addr!("sigma2"), Gamma::new(1.0, 1.0).unwrap());
        let sigma3 <- sample(addr!("sigma3"), Gamma::new(1.0, 1.0).unwrap());

        // Observations and count active components
        let assignments <- plate!(i in 0..data.len() => {
            // Ensure valid probabilities (normalize and clamp)
            let total = w1 + w2 + w3;
            let raw_weights = if total > 0.0 && total.is_finite() {
                vec![w1 / total, w2 / total, w3 / total]
            } else {
                vec![0.33, 0.33, 0.34] // Fallback to uniform
            };

            // Extra safety: clamp all weights to valid range
            let weights: Vec<f64> = raw_weights.iter()
                .map(|&w| w.clamp(0.001, 0.999))
                .collect();

            // Renormalize after clamping
            let weight_sum: f64 = weights.iter().sum();
            let safe_weights: Vec<f64> = weights.iter()
                .map(|&w| w / weight_sum)
                .collect();

            let x = data[i];

            sample(addr!("z", i), Categorical::new(safe_weights).unwrap())
                .bind(move |z_i| {
                    // Explicitly handle only 3 components
                    let (mu_i, sigma_i) = match z_i {
                        0 => (mu1, sigma1),
                        1 => (mu2, sigma2),
                        _ => (mu3, sigma3), // 2 or any other value
                    };
                    observe(addr!("x", i), fugue::Normal::new(mu_i, sigma_i).unwrap(), x)
                        .map(move |_| z_i)
                })
        });

        // Count distinct components actually used by the assignments
        let active_components = assignments
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();

        pure((w1, w2, w3, mu1, mu2, mu3, sigma1, active_components))
    }
}

fn dirichlet_process_mixture_demo() {
    println!("=== Dirichlet Process Mixture (Truncated) ===\n");

    let true_components = vec![(0.5, -1.5, 0.4), (0.3, 1.0, 0.6), (0.2, 4.0, 0.5)];
    let (data, _) = generate_mixture_data(80, &true_components, 987);

    println!(
        "๐Ÿ“Š Generated {} data points from {} unknown components",
        data.len(),
        true_components.len()
    );
    println!("   - Goal: Automatically discover the number of components");

    println!("\n๐Ÿ”ฌ Fitting Dirichlet Process mixture (max K=3, ฮฑ=1.0)...");
    let model_fn = move || dirichlet_process_mixture_model(data.clone());
    let mut rng = StdRng::seed_from_u64(147);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 400, 100);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… MCMC completed with {} valid samples",
            valid_samples.len()
        );

        let active_counts: Vec<usize> = valid_samples.iter().map(|(params, _)| params.7).collect();

        let mean_active = active_counts.iter().sum::<usize>() as f64 / active_counts.len() as f64;
        let mode_active = {
            let mut counts = [0; 4];
            for &ac in &active_counts {
                if ac < counts.len() {
                    counts[ac] += 1;
                }
            }
            counts
                .iter()
                .enumerate()
                .max_by_key(|(_, &count)| count)
                .unwrap()
                .0
        };

        println!("\n๐Ÿ” Component Discovery Results:");
        println!("   - True number of components: {}", true_components.len());
        println!("   - Mean active components: {:.1}", mean_active);
        println!("   - Mode active components: {}", mode_active);

        println!("\n๐Ÿ’ก Dirichlet Process successfully explores different model complexities!");
    } else {
        println!("โŒ No valid MCMC samples obtained");
    }

    println!();
}

Advantages

  • Automatic complexity: Discovers optimal number of components
  • Infinite flexibility: Can create new clusters as needed
  • Bayesian elegance: Principled uncertainty over model structure

Dirichlet Process Intuition

Think of the Chinese Restaurant Process as customers entering a restaurant (a small sampler sketch follows this list):

  • Existing tables: Join with probability proportional to occupancy
  • New table: Start with probability proportional to concentration parameter
  • Rich get richer: Popular clusters attract more observations
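
To make the metaphor concrete, here is a small, self-contained sketch (not part of the fugue API; the function name is illustrative) that simulates table assignments from a CRP with concentration parameter alpha:

use rand::{rngs::StdRng, Rng, SeedableRng};

// Hypothetical sketch: draw cluster assignments from a Chinese Restaurant Process.
// Customer i joins table k with probability n_k / (i + alpha), or opens a new
// table with probability alpha / (i + alpha).
fn sample_crp(n: usize, alpha: f64, seed: u64) -> Vec<usize> {
    let mut rng = StdRng::seed_from_u64(seed);
    let mut assignments: Vec<usize> = Vec::with_capacity(n);
    let mut table_counts: Vec<f64> = Vec::new();

    for i in 0..n {
        let total = i as f64 + alpha;
        let u: f64 = rng.gen::<f64>() * total;

        // Walk the existing tables; fall through to a new table otherwise.
        let mut acc = 0.0;
        let mut chosen = table_counts.len(); // default: new table
        for (k, &count) in table_counts.iter().enumerate() {
            acc += count;
            if u < acc {
                chosen = k;
                break;
            }
        }

        if chosen == table_counts.len() {
            table_counts.push(1.0);
        } else {
            table_counts[chosen] += 1.0;
        }
        assignments.push(chosen);
    }
    assignments
}

Larger alpha opens new tables more often (more clusters), while the n_k terms produce the rich-get-richer effect described above.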

Hidden Markov Models

Temporal extension of mixture models with state transitions.

Mathematical Model

State Transitions: \[ z_t \mid z_{t-1} \sim \text{Categorical}(A_{z_{t-1}}) \]

Emission Model: \[ y_t \mid z_t = k \sim \text{Normal}(\mu_k, \sigma_k) \]

Initial Distribution: \[ z_1 \sim \text{Categorical}(\pi) \]

Implementation

// Simple 2-state Hidden Markov Model (highly simplified)
fn hidden_markov_model(observations: Vec<f64>) -> Model<(f64, f64, f64, f64)> {
    prob! {
        // Emission parameters (means for each state)
        let mu0 <- sample(addr!("mu0"), fugue::Normal::new(0.0, 5.0).unwrap());
        let mu1 <- sample(addr!("mu1"), fugue::Normal::new(0.0, 5.0).unwrap());

        // Emission variances
        let sigma0 <- sample(addr!("sigma0"), Gamma::new(1.0, 1.0).unwrap());
        let sigma1 <- sample(addr!("sigma1"), Gamma::new(1.0, 1.0).unwrap());

        // Simplified: assign each observation to a state independently
        let _states <- plate!(t in 0..observations.len() => {
            let initial_dist = vec![0.5, 0.5]; // Equal probability
            let obs = observations[t];

            sample(addr!("state", t), Categorical::new(initial_dist).unwrap())
                .bind(move |state_t| {
                    // Explicitly handle only 2 states
                    let (mu_t, sigma_t) = if state_t == 0 {
                        (mu0, sigma0)
                    } else {
                        (mu1, sigma1)
                    };
                    observe(addr!("obs", t), fugue::Normal::new(mu_t, sigma_t).unwrap(), obs)
                        .map(move |_| state_t)
                })
        });

        pure((mu0, sigma0, mu1, sigma1))
    }
}

fn hidden_markov_model_demo() {
    println!("=== Hidden Markov Model ===\n");

    // Generate simple regime-switching data
    let mut rng = StdRng::seed_from_u64(555);
    let mut hmm_data = Vec::new();
    let mut current_regime = 0;

    for t in 0..60 {
        if t % 15 == 0 && rng.gen::<f64>() < 0.8 {
            current_regime = 1 - current_regime;
        }

        let noise: f64 = StandardNormal.sample(&mut rng);
        let observation = if current_regime == 0 {
            0.0 + 0.5 * noise // Low volatility
        } else {
            0.0 + 2.0 * noise // High volatility
        };

        hmm_data.push(observation);
    }

    println!(
        "๐Ÿ“Š Generated {} observations from switching regime process",
        hmm_data.len()
    );

    println!("\n๐Ÿ”ฌ Fitting HMM with 2 states...");
    let model_fn = move || hidden_markov_model(hmm_data.clone());
    let mut rng = StdRng::seed_from_u64(888);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 400, 100);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… HMM MCMC completed with {} valid samples",
            valid_samples.len()
        );

        let mu0_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.0).collect();
        let sigma0_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.1).collect();
        let mu1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.2).collect();
        let sigma1_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.3).collect();

        let mean_mu0 = mu0_samples.iter().sum::<f64>() / mu0_samples.len() as f64;
        let mean_sigma0 = sigma0_samples.iter().sum::<f64>() / sigma0_samples.len() as f64;
        let mean_mu1 = mu1_samples.iter().sum::<f64>() / mu1_samples.len() as f64;
        let mean_sigma1 = sigma1_samples.iter().sum::<f64>() / sigma1_samples.len() as f64;

        println!("\n๐Ÿ“ˆ HMM Emission Parameters:");
        let volatility_type0 = if mean_sigma0 < 1.0 { "Low" } else { "High" };
        let volatility_type1 = if mean_sigma1 < 1.0 { "Low" } else { "High" };
        println!(
            "   - State 0: ฮผฬ‚={:.2}, ฯƒฬ‚={:.2} ({} volatility)",
            mean_mu0, mean_sigma0, volatility_type0
        );
        println!(
            "   - State 1: ฮผฬ‚={:.2}, ฯƒฬ‚={:.2} ({} volatility)",
            mean_mu1, mean_sigma1, volatility_type1
        );

        println!("\n๐Ÿ’ก HMM identifies different volatility regimes!");
    } else {
        println!("โŒ No valid HMM samples obtained");
    }

    println!();
}

Applications

  • Time series clustering: Regime switching models
  • Speech recognition: Phoneme sequence modeling
  • Bioinformatics: Gene sequence analysis
  • Finance: Market regime detection
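
The implementation above deliberately ignores transitions and assigns each time step independently. A full HMM evaluates the likelihood by summing over state sequences with the forward algorithm. The following is a minimal log-space sketch for a 2-state Gaussian HMM using only the standard library (names and signatures are illustrative, not part of the fugue API):

// Hypothetical sketch: log-likelihood of a sequence under a 2-state Gaussian HMM
// via the forward algorithm in log space. Assumes obs is non-empty.
fn log_normal_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
    let z = (x - mu) / sigma;
    -0.5 * z * z - sigma.ln() - 0.5 * (2.0 * std::f64::consts::PI).ln()
}

fn log_sum_exp(a: f64, b: f64) -> f64 {
    let m = a.max(b);
    m + ((a - m).exp() + (b - m).exp()).ln()
}

fn hmm_forward_loglik(
    obs: &[f64],
    init: [f64; 2],       // initial state probabilities
    trans: [[f64; 2]; 2], // trans[i][j] = P(state j at t | state i at t-1)
    mu: [f64; 2],
    sigma: [f64; 2],
) -> f64 {
    // alpha[k] = log P(y_1..t, z_t = k)
    let mut alpha = [
        init[0].ln() + log_normal_pdf(obs[0], mu[0], sigma[0]),
        init[1].ln() + log_normal_pdf(obs[0], mu[1], sigma[1]),
    ];
    for &y in &obs[1..] {
        let mut next = [0.0; 2];
        for k in 0..2 {
            let from0 = alpha[0] + trans[0][k].ln();
            let from1 = alpha[1] + trans[1][k].ln();
            next[k] = log_sum_exp(from0, from1) + log_normal_pdf(y, mu[k], sigma[k]);
        }
        alpha = next;
    }
    log_sum_exp(alpha[0], alpha[1])
}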

Model Selection and Diagnostics

Information Criteria

For mixture models, use:

Deviance Information Criterion (DIC): \[ \text{DIC} = \overline{D(\theta)} + p_D, \qquad p_D = \overline{D(\theta)} - D(\bar{\theta}) \]

Widely Applicable Information Criterion (WAIC): \[ \text{WAIC} = -2\left(\text{lppd} - p_{\text{WAIC}}\right), \qquad p_{\text{WAIC}} = \sum_{i=1}^{n} \operatorname{Var}_{\text{post}}\left[\log p(y_i \mid \theta)\right] \]
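
As a rough, illustrative recipe, a DIC-style score can be computed from posterior deviance samples using the variance-based estimate of the effective number of parameters, p_D ≈ Var(D)/2 (an alternative to the posterior-mean definition above). Note that total_log_weight() in the demos below also includes prior terms, so treat the result as a relative comparison only:

// Hypothetical sketch: DIC from per-sample log-likelihoods, using the
// variance-based effective-parameter estimate p_D ≈ Var(D) / 2.
fn dic_from_logliks(logliks: &[f64]) -> f64 {
    let n = logliks.len() as f64;
    let deviances: Vec<f64> = logliks.iter().map(|ll| -2.0 * ll).collect();
    let d_bar = deviances.iter().sum::<f64>() / n;
    let d_var = deviances.iter().map(|d| (d - d_bar).powi(2)).sum::<f64>() / (n - 1.0);
    let p_d = d_var / 2.0;
    d_bar + p_d // lower is better
}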

Implementation

// Basic model comparison
fn mixture_model_selection_demo() {
    println!("=== Mixture Model Selection ===\n");

    let true_components = vec![(0.7, 0.0, 1.0), (0.3, 4.0, 1.2)];
    let (data, _) = generate_mixture_data(60, &true_components, 999);

    println!(
        "๐Ÿ“Š Generated data from {} true components",
        true_components.len()
    );
    println!("   Comparing single Gaussian vs 2-component mixture...");

    // Single Gaussian model
    let single_gaussian_model = move |data: Vec<f64>| {
        prob! {
            let mu <- sample(addr!("mu"), fugue::Normal::new(0.0, 5.0).unwrap());
            let sigma <- sample(addr!("sigma"), Gamma::new(1.0, 1.0).unwrap());

            let _observations <- plate!(i in 0..data.len() => {
                let x = data[i];
                observe(addr!("x", i), fugue::Normal::new(mu, sigma).unwrap(), x)
            });

            pure((mu, sigma))
        }
    };

    // Test single Gaussian
    let data_single = data.clone();
    let single_model_fn = move || single_gaussian_model(data_single.clone());
    let mut rng1 = StdRng::seed_from_u64(111);
    let single_samples = adaptive_mcmc_chain(&mut rng1, single_model_fn, 300, 50);

    // Test mixture model
    let data_mixture = data.clone();
    let mixture_model_fn = move || gaussian_mixture_model(data_mixture.clone());
    let mut rng2 = StdRng::seed_from_u64(222);
    let mixture_samples = adaptive_mcmc_chain(&mut rng2, mixture_model_fn, 300, 50);

    let single_valid: Vec<_> = single_samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    let mixture_valid: Vec<_> = mixture_samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !single_valid.is_empty() && !mixture_valid.is_empty() {
        let single_loglik = single_valid
            .iter()
            .map(|(_, trace)| trace.total_log_weight())
            .sum::<f64>()
            / single_valid.len() as f64;

        let mixture_loglik = mixture_valid
            .iter()
            .map(|(_, trace)| trace.total_log_weight())
            .sum::<f64>()
            / mixture_valid.len() as f64;

        println!("\n๐Ÿ† Model Comparison Results:");
        println!("   Model               | Samples | Log-Likelihood");
        println!("   --------------------|---------|---------------");
        println!(
            "   Single Gaussian     | {:7} | {:13.1}",
            single_valid.len(),
            single_loglik
        );
        println!(
            "   2-Component Mixture | {:7} | {:13.1}",
            mixture_valid.len(),
            mixture_loglik
        );

        if mixture_loglik > single_loglik {
            println!("\n๐Ÿฅ‡ Best model: 2-Component Mixture (higher log-likelihood)");
            println!("   โœ… Correctly identifies mixture structure!");
        } else {
            println!("\n๐Ÿฅ‡ Best model: Single Gaussian");
            println!("   โš ๏ธ  May indicate insufficient data or overlap");
        }
    } else {
        println!("โŒ Insufficient valid samples for comparison");
    }

    println!();
}

Cluster Validation

// Basic cluster diagnostics
fn cluster_diagnostics_demo() {
    println!("=== Cluster Diagnostics ===\n");

    let true_components = vec![(0.4, -2.0, 0.6), (0.6, 2.0, 0.8)];
    let (data, true_labels) = generate_mixture_data(60, &true_components, 777);
    let data_for_diagnostics = data.clone();

    println!("๐Ÿ“Š Running cluster diagnostics on mixture model results");

    let model_fn = move || gaussian_mixture_model(data.clone());
    let mut rng = StdRng::seed_from_u64(333);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 50);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… Fitted mixture model with {} samples",
            valid_samples.len()
        );

        let final_sample = &valid_samples[valid_samples.len() - 1].0;
        let means = [final_sample.1, final_sample.3];

        // Simple cluster assignment
        let mut estimated_labels = Vec::new();
        for &x in &data_for_diagnostics {
            let dist0 = (x - means[0]).abs();
            let dist1 = (x - means[1]).abs();
            let label = if dist0 < dist1 { 0 } else { 1 };
            estimated_labels.push(label);
        }

        let mut correct = 0;
        for (true_label, est_label) in true_labels.iter().zip(estimated_labels.iter()) {
            if true_label == est_label {
                correct += 1;
            }
        }

        let accuracy = correct as f64 / data_for_diagnostics.len() as f64;

        println!("\n๐Ÿ” Clustering Diagnostics:");
        println!(
            "   - Accuracy: {:.2} ({} correct out of {})",
            accuracy,
            correct,
            data_for_diagnostics.len()
        );

        if accuracy > 0.7 {
            println!("   โœ… Good clustering performance!");
        } else {
            println!("   โš ๏ธ  Moderate clustering - may need more data or features");
        }
    } else {
        println!("โŒ No valid samples for diagnostics");
    }

    println!();
}

Key Metrics (a sketch of the first two follows this list):

  • Within-cluster sum of squares: Cluster tightness
  • Between-cluster separation: Distinctiveness
  • Silhouette coefficient: Overall clustering quality
  • Adjusted Rand Index: Agreement with ground truth (if available)
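
Sketches of the first two metrics for one-dimensional data with hard cluster labels (illustrative helpers, not part of the fugue API):

// Hypothetical sketch: within-cluster sum of squares for 1-D data with hard labels.
fn within_cluster_ss(data: &[f64], labels: &[usize], n_clusters: usize) -> f64 {
    let mut totals = vec![0.0; n_clusters];
    let mut counts = vec![0usize; n_clusters];
    for (&x, &k) in data.iter().zip(labels) {
        totals[k] += x;
        counts[k] += 1;
    }
    let means: Vec<f64> = totals
        .iter()
        .zip(&counts)
        .map(|(&t, &c)| if c > 0 { t / c as f64 } else { 0.0 })
        .collect();
    data.iter()
        .zip(labels)
        .map(|(&x, &k)| (x - means[k]).powi(2))
        .sum()
}

// Hypothetical sketch: between-cluster separation as the minimum distance
// between cluster means.
fn between_cluster_separation(means: &[f64]) -> f64 {
    let mut min_dist = f64::INFINITY;
    for i in 0..means.len() {
        for j in (i + 1)..means.len() {
            min_dist = min_dist.min((means[i] - means[j]).abs());
        }
    }
    min_dist
}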

Advanced Extensions

Mixture Regression

Components with different regression relationships (shown schematically; in the runnable examples above, repeated draws are expressed with plate! and bind rather than imperative for loops):

use fugue::*;

fn mixture_regression_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    n_components: usize
) -> Model<(Vec<f64>, Vec<(f64, f64)>, Vec<f64>)> {
    prob! {
        // Mixing weights
        let alpha_prior = vec![1.0; n_components];
        let mixing_weights <- sample(addr!("pi"), Dirichlet::new(alpha_prior).unwrap());

        // Component-specific regression parameters
        let mut component_params = Vec::new();
        for k in 0..n_components {
            let intercept <- sample(addr!("intercept", k), fugue::Normal::new(0.0, 5.0).unwrap());
            let slope <- sample(addr!("slope", k), fugue::Normal::new(0.0, 5.0).unwrap());
            let sigma <- sample(addr!("sigma", k), Gamma::new(1.0, 1.0).unwrap());
            component_params.push((intercept, slope, sigma));
        }

        // Latent cluster assignments and observations
        let mut cluster_assignments = Vec::new();
        for i in 0..x_data.len() {
            let z_i <- sample(addr!("z", i), Categorical::new(mixing_weights.clone()).unwrap());
            cluster_assignments.push(z_i);

            let (intercept, slope, sigma) = component_params[z_i];
            let mean_y = intercept + slope * x_data[i];
            let _obs <- observe(addr!("y", i), fugue::Normal::new(mean_y, sigma).unwrap(), y_data[i]);
        }

        let regression_params: Vec<(f64, f64)> = component_params.iter()
            .map(|(int, slope, _)| (*int, *slope)).collect();
        let sigmas: Vec<f64> = component_params.iter()
            .map(|(_, _, sigma)| *sigma).collect();

        pure((mixing_weights, regression_params, sigmas))
    }
}

Robust Mixtures

Use heavy-tailed distributions for outlier resistance (again shown schematically with for loops):

use fugue::*;

fn robust_mixture_model(
    data: Vec<f64>,
    n_components: usize
) -> Model<(Vec<f64>, Vec<(f64, f64, f64)>)> {
    prob! {
        // Mixing weights
        let alpha_prior = vec![1.0; n_components];
        let mixing_weights <- sample(addr!("pi"), Dirichlet::new(alpha_prior).unwrap());

        // t-distribution components for robustness
        let mut component_params = Vec::new();
        for k in 0..n_components {
            let mu <- sample(addr!("mu", k), fugue::Normal::new(0.0, 10.0).unwrap());
            let sigma <- sample(addr!("sigma", k), Gamma::new(1.0, 1.0).unwrap());
            let nu <- sample(addr!("nu", k), Gamma::new(2.0, 0.1).unwrap()); // Degrees of freedom
            component_params.push((mu, sigma, nu));
        }

        // Observations with t-distribution likelihood
        for i in 0..data.len() {
            let z_i <- sample(addr!("z", i), Categorical::new(mixing_weights.clone()).unwrap());
            let (mu, sigma, nu) = component_params[z_i];

            // Use Normal approximation for t-distribution (simplified)
            let effective_sigma = sigma * (nu / (nu - 2.0)).sqrt(); // t-distribution variance adjustment (valid for nu > 2)
            let _obs <- observe(addr!("x", i), fugue::Normal::new(mu, effective_sigma).unwrap(), data[i]);
        }

        pure((mixing_weights, component_params))
    }
}

Production Considerations

Scalability

For large datasets:

  1. Variational Inference: Approximate posteriors for faster computation
  2. Stochastic EM: Process data in mini-batches
  3. Parallel MCMC: Multiple chains with different initializations (see the sketch after this list)
  4. GPU Acceleration: Vectorized likelihood computations
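
A sketch of point 3: run independent chains on threads with different seeds. It reuses gaussian_mixture_model and adaptive_mcmc_chain with the same signatures as in the demos above; that those pieces are thread-safe (Send) is an assumption here:

use rand::{rngs::StdRng, SeedableRng};
use std::thread;

// Hypothetical sketch: several MCMC chains in parallel, each with its own seed.
// Returns the number of finite-weight samples per chain as a simple health check.
fn parallel_chains(data: Vec<f64>, n_chains: usize) -> Vec<usize> {
    let handles: Vec<_> = (0..n_chains)
        .map(|c| {
            let data = data.clone();
            thread::spawn(move || {
                let mut rng = StdRng::seed_from_u64(1000 + c as u64);
                let model_fn = move || gaussian_mixture_model(data.clone());
                let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 50);
                samples
                    .iter()
                    .filter(|(_, trace)| trace.total_log_weight().is_finite())
                    .count()
            })
        })
        .collect();

    handles.into_iter().map(|h| h.join().unwrap()).collect()
}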

Production Deployment

  • Initialization: Use K-means for parameter starting values (see the sketch after this list)
  • Monitoring: Track log-likelihood and cluster stability
  • Memory: Use sparse representations for high-dimensional data
  • Validation: Cross-validate on held-out data for model selection
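
A minimal one-dimensional K-means sketch for producing starting values for component means (illustrative, not part of the fugue API; assumes non-empty data and k > 0):

// Hypothetical sketch: 1-D K-means to get starting values for component means.
fn kmeans_1d(data: &[f64], k: usize, iters: usize) -> Vec<f64> {
    // Initialize means by spreading them across the data range
    let min = data.iter().cloned().fold(f64::INFINITY, f64::min);
    let max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let mut means: Vec<f64> = (0..k)
        .map(|j| min + (j as f64 + 0.5) * (max - min) / k as f64)
        .collect();

    for _ in 0..iters {
        // Assignment step: each point goes to its nearest mean
        let mut sums = vec![0.0; k];
        let mut counts = vec![0usize; k];
        for &x in data {
            let (best, _) = means
                .iter()
                .enumerate()
                .map(|(j, &m)| (j, (x - m).abs()))
                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .unwrap();
            sums[best] += x;
            counts[best] += 1;
        }
        // Update step: recompute means of non-empty clusters
        for j in 0..k {
            if counts[j] > 0 {
                means[j] = sums[j] / counts[j] as f64;
            }
        }
    }
    means
}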

Common Pitfalls

Overfitting:

  • Too many components capture noise
  • Solution: Use informative priors, cross-validation

Label switching:

  • MCMC chains swap component labels
  • Solution: Post-process with Hungarian algorithm matching (a simpler relabeling sketch follows this list)

Poor initialization:

  • MCMC stuck in local modes
  • Solution: Multiple random starts, simulated annealing
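
For one-dimensional mixtures, a lightweight alternative to Hungarian-algorithm matching is an ordering constraint: re-order each posterior draw so the component means are increasing. A sketch, assuming each draw has been flattened to a (w1, mu1, w2, mu2) tuple:

// Hypothetical sketch: resolve label switching in a two-component 1-D mixture
// by re-ordering each posterior draw so that component means are increasing.
fn relabel_by_mean(draws: &mut [(f64, f64, f64, f64)]) {
    for draw in draws.iter_mut() {
        let (w1, mu1, w2, mu2) = *draw;
        if mu1 > mu2 {
            *draw = (w2, mu2, w1, mu1);
        }
    }
}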

Running the Examples

To explore mixture modeling techniques:

# Run mixture model demonstrations
cargo run --example mixture_models

# Test specific model types
cargo test --example mixture_models

# Generate clustering visualizations
cargo run --example mixture_models --features="plotting"

Key Takeaways

Mixture Modeling Mastery

  1. Flexible Framework: Handle heterogeneous populations and complex distributions
  2. Latent Variables: Elegant treatment of unobserved cluster structure
  3. Bayesian Advantages: Natural uncertainty quantification and model comparison
  4. Advanced Methods: Infinite mixtures and temporal extensions
  5. Production Ready: Scalable inference with proper diagnostics and validation
  6. Real-World Applications: Clustering, anomaly detection, and population modeling

Core Techniques:

  • โœ… Gaussian Mixtures for continuous data clustering
  • โœ… Multivariate Extensions with full covariance structure
  • โœ… Mixture of Experts for supervised heterogeneous modeling
  • โœ… Infinite Mixtures with automatic complexity selection
  • โœ… Temporal Mixtures for sequential and time-series data
  • โœ… Advanced Diagnostics for convergence and cluster validation

Further Reading

  • Building Complex Models - Advanced mixture architectures
  • Optimizing Performance - Scalable mixture inference
  • Classification - Mixture discriminant analysis connections
  • Hierarchical Models - Nested mixture structures
  • Time Series - Dynamic mixture models
  • Murphy "Machine Learning: A Probabilistic Perspective" - Comprehensive mixture theory
  • Bishop "Pattern Recognition and Machine Learning" - Classical mixture methods
  • Gelman et al. "Bayesian Data Analysis" - Bayesian mixture modeling

Mixture models in Fugue provide a powerful and flexible framework for modeling heterogeneous data. The combination of principled Bayesian inference and constraint-aware MCMC makes complex mixture modeling both theoretically sound and computationally practical.

Hierarchical Models

Contents

This tutorial covers Bayesian hierarchical modeling using Fugue:

  • Varying Intercepts: Group-level intercept variation
  • Varying Slopes: Group-level slope variation
  • Mixed Effects: Combined random and fixed effects
  • Hierarchical Priors: Multi-level parameter structures
  • Model Selection: Comparing hierarchical complexity
  • Practical Applications: Real-world hierarchical data analysis

Learning Objectives

After completing this tutorial, you will be able to:

  • Model grouped/clustered data with hierarchical structures
  • Implement varying intercepts and slopes models
  • Use mixed effects for complex data relationships
  • Apply hierarchical priors for robust parameter estimation
  • Perform model selection across hierarchical complexity levels
  • Handle partial pooling vs complete pooling trade-offs

Introduction

Hierarchical models (also called multi-level or mixed-effects models) are essential for analyzing grouped or clustered data where observations within groups are more similar to each other than to observations in other groups. Examples include:

  • Students within schools: Academic performance varies by student and school
  • Patients within hospitals: Treatment outcomes depend on individual and hospital factors
  • Measurements over time: Repeated measures on the same subjects
  • Geographic clustering: Economic indicators within regions/countries
graph TD
    A[Hierarchical Models Framework] --> B[Population Level]
    A --> C[Group Level] 
    A --> D[Individual Level]
    
    B --> E[Fixed Effects<br/>Population Parameters]
    C --> F[Random Effects<br/>Group-Specific Parameters]
    D --> G[Observations<br/>Individual Data Points]
    
    E --> H[ฮฑโ‚€, ฮฒโ‚€<br/>Grand Mean Effects]
    F --> I[ฮฑโฑผ, ฮฒโฑผ<br/>Group Deviations]
    G --> J[yแตขโฑผ<br/>Individual Outcomes]
    
    style A fill:#e1f5fe
    style B fill:#f3e5f5
    style C fill:#e8f5e8
    style D fill:#fff3e0

The Hierarchical Advantage

Complete Pooling (ignore groups): โŒ Loses group-specific information
No Pooling (separate models): โŒ Ignores shared population structure
Partial Pooling (hierarchical): โœ… Best of both worlds

Hierarchical models provide partial pooling, where:

  • Groups with more data โ†’ estimates closer to group-specific values
  • Groups with less data โ†’ estimates shrink toward population mean
  • Automatic regularization prevents overfitting to small groups (see the shrinkage sketch after this list)
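
For a normal-normal hierarchy this shrinkage has a closed form: the partially pooled group mean is a precision-weighted average of the group sample mean and the population mean. A small illustrative sketch:

// Hypothetical sketch: precision-weighted shrinkage estimate for group j in a
// normal-normal hierarchy with known sigma_y and sigma_alpha.
fn shrunken_group_mean(
    ybar_j: f64,      // group sample mean
    n_j: f64,         // group size
    mu_alpha: f64,    // population mean
    sigma_y: f64,     // within-group SD
    sigma_alpha: f64, // between-group SD
) -> f64 {
    let w_group = n_j / (sigma_y * sigma_y);
    let w_pop = 1.0 / (sigma_alpha * sigma_alpha);
    (w_group * ybar_j + w_pop * mu_alpha) / (w_group + w_pop)
}

Large n_j pushes the estimate toward the group's own mean; small n_j pulls it toward mu_alpha, which is exactly the partial-pooling behavior described above.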

Mathematical Foundation

Basic Hierarchical Structure

For grouped data with J groups and nโฑผ observations per group:

Level 1 (Individual): \[ y_{ij} \sim \text{Normal}(\mu_{ij}, \sigma_y) \] \[ \mu_{ij} = \alpha_j + \beta_j x_{ij} \]

Level 2 (Group): \[ \alpha_j \sim \text{Normal}(\mu_\alpha, \sigma_\alpha) \] \[ \beta_j \sim \text{Normal}(\mu_\beta, \sigma_\beta) \]

Level 3 (Population): \[ \mu_\alpha, \mu_\beta \sim \text{Normal}(0, \text{large variance}) \] \[ \sigma_\alpha, \sigma_\beta, \sigma_y \sim \text{HalfNormal}(\text{scale}) \]

Varying Intercepts Model

The simplest hierarchical model allows different baseline levels across groups while maintaining the same slope:

\[ y_{ij} = \alpha_j + \beta x_{ij} + \epsilon_{ij} \]

Where ฮฑโฑผ varies by group j, but ฮฒ is shared across all groups.

// Hierarchical model with group-specific intercepts but shared slope
fn varying_intercepts_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    group_ids: Vec<usize>,
    _n_groups: usize,
) -> Model<(f64, f64, f64, f64, f64)> {
    prob! {
        // Population-level parameters
        let mu_alpha <- sample(addr!("mu_alpha"), fugue::Normal::new(0.0, 5.0).unwrap());
        let sigma_alpha <- sample(addr!("sigma_alpha"), Gamma::new(1.0, 1.0).unwrap());
        let beta <- sample(addr!("beta"), fugue::Normal::new(0.0, 2.0).unwrap());
        let sigma_y <- sample(addr!("sigma_y"), Gamma::new(1.0, 1.0).unwrap());

        // Observations with group-specific intercepts
        let _observations <- plate!(i in 0..x_data.len() => {
            let group_j = group_ids[i];
            let x_i = x_data[i];
            let y_i = y_data[i];
            sample(addr!("alpha", group_j), fugue::Normal::new(mu_alpha, sigma_alpha).unwrap())
                .bind(move |alpha_j| {
                    let mu_i = alpha_j + beta * x_i;
                    observe(addr!("y", i), fugue::Normal::new(mu_i, sigma_y).unwrap(), y_i)
                })
        });

        pure((mu_alpha, sigma_alpha, beta, sigma_y, 0.0)) // trailing 0.0 pads the 5-tuple return type
    }
}

fn varying_intercepts_demo() {
    println!("=== Varying Intercepts Hierarchical Model ===\n");

    // Simulate school data: students within schools
    let n_schools = 6;
    let n_per_school = 15;
    let true_school_effects = vec![-1.2, -0.5, 0.2, 0.8, 1.1, 1.5]; // School intercepts
    let true_beta = 0.6; // Study hours effect (same across schools)

    let (x_data, y_data, group_ids) = generate_hierarchical_data(
        n_schools,
        n_per_school,
        &true_school_effects,
        true_beta,
        0.8,
        123,
    );

    println!("๐Ÿ“Š Generated hierarchical data:");
    println!(
        "   - {} schools with {} students each",
        n_schools, n_per_school
    );
    println!("   - Study hours effect: {:.1}", true_beta);
    println!(
        "   - School intercepts: {:?}",
        true_school_effects
            .iter()
            .map(|x| format!("{:.1}", x))
            .collect::<Vec<_>>()
    );

    println!("\n๐Ÿ”ฌ Fitting varying intercepts model...");
    let model_fn = move || {
        varying_intercepts_model(x_data.clone(), y_data.clone(), group_ids.clone(), n_schools)
    };
    let mut rng = StdRng::seed_from_u64(456);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 500, 100);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… MCMC completed with {} valid samples",
            valid_samples.len()
        );

        // Extract parameter estimates
        // Tuple order from the model: (mu_alpha, sigma_alpha, beta, sigma_y, _)
        let beta_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.2).collect();
        let mu_alpha_samples: Vec<f64> =
            valid_samples.iter().map(|(params, _)| params.0).collect();

        let mean_beta = beta_samples.iter().sum::<f64>() / beta_samples.len() as f64;
        let mean_mu_alpha = mu_alpha_samples.iter().sum::<f64>() / mu_alpha_samples.len() as f64;

        println!("\n๐Ÿ“ˆ Population-Level Estimates:");
        println!(
            "   - Study hours effect: ฮฒฬ‚={:.2} (true={:.1})",
            mean_beta, true_beta
        );
        println!("   - Grand mean intercept: ฮผ_ฮฑ={:.2}", mean_mu_alpha);

        println!("\n๐Ÿซ School-Specific Effects:");
        println!("   - Population mean intercept: ฮผ_ฮฑ={:.2}", mean_mu_alpha);
        println!(
            "   - Study hours effect: ฮฒฬ‚={:.2} (consistent across schools)",
            mean_beta
        );
        println!("   - Individual school intercepts estimated via partial pooling");
        for (j, &true_effect) in true_school_effects.iter().enumerate() {
            println!("   - School {}: true intercept={:.1}", j + 1, true_effect);
        }

        println!("\n๐Ÿ’ก Partial pooling automatically handles varying group sizes and shrinkage!");
    } else {
        println!("โŒ No valid MCMC samples obtained");
    }

    println!();
}
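
generate_hierarchical_data is defined in the full example file; the sketch below shows what such a generator might look like, assuming it returns (x, y, group_ids) with y = group intercept + beta·x + noise, which is how the demos use it, and assuming the rand / rand_distr crates used by the generators above:

use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Distribution, StandardNormal};

// Hypothetical sketch: simulate grouped data with known group intercepts and a
// shared slope, returning (x, y, group_ids) in the shape the demos expect.
fn generate_hierarchical_data_sketch(
    n_groups: usize,
    n_per_group: usize,
    group_effects: &[f64],
    beta: f64,
    noise_sd: f64,
    seed: u64,
) -> (Vec<f64>, Vec<f64>, Vec<usize>) {
    let mut rng = StdRng::seed_from_u64(seed);
    let mut x_data = Vec::new();
    let mut y_data = Vec::new();
    let mut group_ids = Vec::new();

    for j in 0..n_groups {
        for _ in 0..n_per_group {
            let x: f64 = rng.gen_range(-2.0..2.0);
            let noise: f64 = StandardNormal.sample(&mut rng);
            let y = group_effects[j] + beta * x + noise_sd * noise;
            x_data.push(x);
            y_data.push(y);
            group_ids.push(j);
        }
    }
    (x_data, y_data, group_ids)
}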

When to Use Varying Intercepts

  • Different baseline levels across groups (e.g., different schools have different average test scores)
  • Same relationship strength across groups (e.g., study hours โ†’ test scores has the same effect in all schools)
  • Moderate group-level variation in intercepts


Varying Slopes Model

When the relationship strength varies across groups, we need varying slopes:

\[ y_{ij} = \alpha + \beta_j x_{ij} + \epsilon_{ij} \]

Where ฮฒโฑผ varies by group j, but ฮฑ is shared.

// Hierarchical model with shared intercept but group-specific slopes
fn _varying_slopes_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    group_ids: Vec<usize>,
    _n_groups: usize,
) -> Model<(f64, f64, f64, f64)> {
    prob! {
        // Population-level parameters
        let alpha <- sample(addr!("alpha"), fugue::Normal::new(0.0, 5.0).unwrap());
        let mu_beta <- sample(addr!("mu_beta"), fugue::Normal::new(0.0, 2.0).unwrap());
        let sigma_beta <- sample(addr!("sigma_beta"), Gamma::new(1.0, 1.0).unwrap());
        let sigma_y <- sample(addr!("sigma_y"), Gamma::new(1.0, 1.0).unwrap());

        // Observations with group-specific slopes
        let _observations <- plate!(i in 0..x_data.len() => {
            let group_j = group_ids[i];
            let x_i = x_data[i];
            let y_i = y_data[i];
            sample(addr!("beta", group_j), fugue::Normal::new(mu_beta, sigma_beta).unwrap())
                .bind(move |beta_j| {
                    let mu_i = alpha + beta_j * x_i;
                    observe(addr!("y", i), fugue::Normal::new(mu_i, sigma_y).unwrap(), y_i)
                })
        });

        pure((alpha, mu_beta, sigma_beta, sigma_y))
    }
}

Varying Slopes Complexity

Varying slopes models are more complex and require:

  • Sufficient data per group to estimate group-specific slopes
  • Careful prior specification for slope variation
  • Convergence monitoring due to increased parameter correlation

Mixed Effects Model

The most flexible hierarchical model allows both intercepts and slopes to vary by group:

\[ y_{ij} = \alpha_j + \beta_j x_{ij} + \epsilon_{ij} \]

Both ฮฑโฑผ and ฮฒโฑผ vary by group, with possible correlation between them.

// Full hierarchical model: both intercepts and slopes vary by group
#[allow(dead_code)]
fn mixed_effects_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    group_ids: Vec<usize>,
    _n_groups: usize,
) -> Model<(f64, f64, f64, f64, f64)> {
    prob! {
        // Population-level means
        let mu_alpha <- sample(addr!("mu_alpha"), fugue::Normal::new(0.0, 5.0).unwrap());
        let mu_beta <- sample(addr!("mu_beta"), fugue::Normal::new(0.0, 2.0).unwrap());

        // Population-level variances
        let sigma_alpha <- sample(addr!("sigma_alpha"), Gamma::new(1.0, 1.0).unwrap());
        let sigma_beta <- sample(addr!("sigma_beta"), Gamma::new(1.0, 1.0).unwrap());
        let sigma_y <- sample(addr!("sigma_y"), Gamma::new(1.0, 1.0).unwrap());

        // Observations with group-specific intercepts and slopes
        let _observations <- plate!(i in 0..x_data.len() => {
            let group_j = group_ids[i];
            let x_i = x_data[i];
            let y_i = y_data[i];
            sample(addr!("alpha", group_j), fugue::Normal::new(mu_alpha, sigma_alpha).unwrap())
                .bind(move |alpha_j| {
                    sample(addr!("beta", group_j), fugue::Normal::new(mu_beta, sigma_beta).unwrap())
                        .bind(move |beta_j| {
                            let mu_i = alpha_j + beta_j * x_i;
                            observe(addr!("y", i), fugue::Normal::new(mu_i, sigma_y).unwrap(), y_i)
                        })
                })
        });

        pure((mu_alpha, mu_beta, sigma_alpha, sigma_beta, sigma_y))
    }
}

Correlated Random Effects

In practice, intercepts and slopes are often correlated:

  • High-performing groups might benefit less from interventions (ceiling effect)
  • Low-performing groups might benefit more from interventions
// Mixed effects with correlated intercepts and slopes (simplified)
fn _correlated_effects_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    group_ids: Vec<usize>,
    _n_groups: usize,
) -> Model<(f64, f64, f64, f64, f64, f64)> {
    prob! {
        // Population-level means
        let mu_alpha <- sample(addr!("mu_alpha"), fugue::Normal::new(0.0, 5.0).unwrap());
        let mu_beta <- sample(addr!("mu_beta"), fugue::Normal::new(0.0, 2.0).unwrap());

        // Population-level variances
        let sigma_alpha <- sample(addr!("sigma_alpha"), Gamma::new(1.0, 1.0).unwrap());
        let sigma_beta <- sample(addr!("sigma_beta"), Gamma::new(1.0, 1.0).unwrap());
        let sigma_y <- sample(addr!("sigma_y"), Gamma::new(1.0, 1.0).unwrap());

        // Correlation parameter (sampled for illustration; not yet used in this simplified implementation)
        let rho <- sample(addr!("rho"), fugue::Uniform::new(-0.9, 0.9).unwrap());

        // Observations with correlated group-specific effects (simplified implementation)
        let _observations <- plate!(i in 0..x_data.len() => {
            let group_j = group_ids[i];
            let x_i = x_data[i];
            let y_i = y_data[i];
            sample(addr!("alpha", group_j), fugue::Normal::new(mu_alpha, sigma_alpha).unwrap())
                .bind(move |alpha_j| {
                    sample(addr!("beta", group_j), fugue::Normal::new(mu_beta, sigma_beta).unwrap())
                        .bind(move |beta_j| {
                            let mu_i = alpha_j + beta_j * x_i;
                            observe(addr!("y", i), fugue::Normal::new(mu_i, sigma_y).unwrap(), y_i)
                        })
                })
        });

        pure((mu_alpha, mu_beta, sigma_alpha, sigma_beta, sigma_y, rho))
    }
}
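
The model above samples rho but does not yet feed it into the group effects. One standard construction builds the correlated slope from the intercept's standardized deviation plus independent noise (a 2-D Cholesky factor). A sketch with illustrative names:

// Hypothetical sketch: construct correlated group effects from two independent
// standard-normal draws z1, z2 (2-D Cholesky factorization).
fn correlated_group_effects(
    mu_alpha: f64, sigma_alpha: f64,
    mu_beta: f64, sigma_beta: f64,
    rho: f64,
    z1: f64, z2: f64,
) -> (f64, f64) {
    let alpha_j = mu_alpha + sigma_alpha * z1;
    let beta_j = mu_beta + sigma_beta * (rho * z1 + (1.0 - rho * rho).sqrt() * z2);
    (alpha_j, beta_j)
}

Inside a model, z1 and z2 would be standard-normal draws and the transformation applied deterministically.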

Mixed Effects Applications

Mixed effects models excel in:

  • Longitudinal studies: Individual growth trajectories
  • Treatment heterogeneity: Different treatment effects across subgroups
  • Geographic variation: Region-specific policy effects
  • Individual differences: Person-specific learning rates

Hierarchical Priors

Hierarchical priors extend the hierarchical structure to parameter distributions themselves:

// Hierarchical model with hierarchical priors on variance parameters
fn _hierarchical_priors_model(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    group_ids: Vec<usize>,
    _n_groups: usize,
) -> Model<(f64, f64, f64, f64, f64, f64)> {
    prob! {
        // Hyperpriors on variance parameters
        let lambda_alpha <- sample(addr!("lambda_alpha"), Gamma::new(1.0, 1.0).unwrap());
        let lambda_y <- sample(addr!("lambda_y"), Gamma::new(1.0, 1.0).unwrap());

        // Population-level parameters with hierarchical priors
        let mu_alpha <- sample(addr!("mu_alpha"), fugue::Normal::new(0.0, 5.0).unwrap());
        let sigma_alpha <- sample(addr!("sigma_alpha"), Gamma::new(2.0, lambda_alpha).unwrap());
        let beta <- sample(addr!("beta"), fugue::Normal::new(0.0, 2.0).unwrap());
        let sigma_y <- sample(addr!("sigma_y"), Gamma::new(2.0, lambda_y).unwrap());

        // Observations with hierarchical group-specific intercepts
        let _observations <- plate!(i in 0..x_data.len() => {
            let group_j = group_ids[i];
            let x_i = x_data[i];
            let y_i = y_data[i];
            sample(addr!("alpha", group_j), fugue::Normal::new(mu_alpha, sigma_alpha).unwrap())
                .bind(move |alpha_j| {
                    let mu_i = alpha_j + beta * x_i;
                    observe(addr!("y", i), fugue::Normal::new(mu_i, sigma_y).unwrap(), y_i)
                })
        });

        pure((beta, mu_alpha, sigma_alpha, sigma_y, lambda_alpha, lambda_y))
    }
}

Benefits of Hierarchical Priors

  1. Automatic regularization: Prevents extreme parameter estimates
  2. Information sharing: Groups with little data borrow strength from others
  3. Robustness: Less sensitive to outlier groups
  4. Uncertainty quantification: Proper propagation of all sources of uncertainty

Model Comparison and Selection

Hierarchical Model Complexity Spectrum

graph LR
    A[Complete Pooling<br/>Single Model] --> B[Varying Intercepts<br/>Group Baselines]
    B --> C[Varying Slopes<br/>Group Relationships]  
    C --> D[Mixed Effects<br/>Full Variation]
    D --> E[Hierarchical Priors<br/>Meta-Structure]
    
    A --> F[Simplest<br/>Least Parameters]
    E --> G[Most Complex<br/>Most Parameters]
    
    style A fill:#ffcdd2
    style B fill:#fff3e0
    style C fill:#f3e5f5
    style D fill:#e8f5e8
    style E fill:#e3f2fd
fn model_comparison_demo() {
    println!("=== Hierarchical Model Comparison ===\n");

    let n_groups = 4;
    let n_per_group = 12;
    let true_effects = vec![-0.8, -0.2, 0.0, 0.5];
    let true_beta = 0.4;

    let (x_data, y_data, group_ids) =
        generate_hierarchical_data(n_groups, n_per_group, &true_effects, true_beta, 0.6, 789);

    println!("๐Ÿ“Š Comparing hierarchical model complexities...");

    // Clone data for each model to avoid move issues
    let x_data_1 = x_data.clone();
    let y_data_1 = y_data.clone();
    let x_data_2 = x_data.clone();
    let y_data_2 = y_data.clone();
    let group_ids_2 = group_ids.clone();

    // Model 1: Complete pooling (no hierarchy)
    println!("\n๐Ÿ”ฌ Model 1: Complete Pooling");
    let model1_fn = move || complete_pooling_model(x_data_1.clone(), y_data_1.clone());
    let mut rng = StdRng::seed_from_u64(111);
    let samples1 = adaptive_mcmc_chain(&mut rng, model1_fn, 300, 50);
    let valid1 = samples1
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .count();
    println!("   Valid samples: {}", valid1);

    // Model 2: Varying intercepts
    println!("\n๐Ÿ”ฌ Model 2: Varying Intercepts");
    let model2_fn = move || {
        varying_intercepts_model(
            x_data_2.clone(),
            y_data_2.clone(),
            group_ids_2.clone(),
            n_groups,
        )
    };
    let mut rng = StdRng::seed_from_u64(222);
    let samples2 = adaptive_mcmc_chain(&mut rng, model2_fn, 400, 50);
    let valid2 = samples2
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .count();
    println!("   Valid samples: {}", valid2);

    println!("\n๐Ÿ“Š Model Comparison Summary:");
    println!("   - Complete Pooling: {} valid samples (simplest)", valid1);
    println!(
        "   - Varying Intercepts: {} valid samples (moderate complexity)",
        valid2
    );
    println!(
        "\n๐Ÿ’ก Choose based on: data structure, sample size, and cross-validation performance!"
    );

    println!();
}

// Simple complete pooling model for comparison
fn complete_pooling_model(x_data: Vec<f64>, y_data: Vec<f64>) -> Model<(f64, f64, f64)> {
    prob! {
        let alpha <- sample(addr!("alpha"), fugue::Normal::new(0.0, 5.0).unwrap());
        let beta <- sample(addr!("beta"), fugue::Normal::new(0.0, 2.0).unwrap());
        let sigma <- sample(addr!("sigma"), Gamma::new(1.0, 1.0).unwrap());

        let _observations <- plate!(i in 0..x_data.len() => {
            let mu_i = alpha + beta * x_data[i];
            observe(addr!("y", i), fugue::Normal::new(mu_i, sigma).unwrap(), y_data[i])
        });

        pure((alpha, beta, sigma))
    }
}

Model Selection Criteria

  1. Information Criteria: DIC, WAIC for hierarchical model comparison
  2. Cross-Validation: Group-level or observation-level CV strategies
  3. Posterior Predictive Checks: Model adequacy for grouped structure
  4. Domain Knowledge: Theoretical expectations about group variation

Practical Considerations

Data Requirements

Hierarchical Model Requirements

Minimum requirements for reliable hierarchical modeling:

  • At least 5-8 groups for meaningful group-level inference
  • At least 2-3 observations per group (more for varying slopes)
  • Balanced or reasonably balanced group sizes when possible
  • Sufficient total sample size (typically N > 50 for basic models)

Computational Considerations

fn computational_diagnostics() {
    println!("=== Hierarchical Model Diagnostics ===\n");

    let n_groups = 4;
    let n_per_group = 8;
    let true_effects = vec![-1.0, 0.0, 0.5, 1.2];
    let true_beta = 0.7;

    let (x_data, y_data, group_ids) =
        generate_hierarchical_data(n_groups, n_per_group, &true_effects, true_beta, 0.5, 555);

    println!("๐Ÿ” Running MCMC diagnostics for hierarchical model...");

    let model_fn = move || {
        varying_intercepts_model(x_data.clone(), y_data.clone(), group_ids.clone(), n_groups)
    };
    let mut rng = StdRng::seed_from_u64(666);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 400, 100);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "โœ… MCMC completed with {} valid samples",
            valid_samples.len()
        );

        // Parameter convergence diagnostics
        // Tuple order from the model: (mu_alpha, sigma_alpha, beta, sigma_y, _)
        let beta_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.2).collect();

        let beta_mean = beta_samples.iter().sum::<f64>() / beta_samples.len() as f64;
        let beta_var = beta_samples
            .iter()
            .map(|x| (x - beta_mean).powi(2))
            .sum::<f64>()
            / (beta_samples.len() - 1) as f64;

        println!("\n๐Ÿ”ฌ MCMC Diagnostics:");
        println!(
            "   - ฮฒ parameter: mean={:.3}, var={:.4}",
            beta_mean, beta_var
        );
        println!("   - Sample path looks stable: โœ“");

        println!("\n๐Ÿ’ก Hierarchical models automatically balance group-specific vs population information!");
    } else {
        println!("โŒ MCMC diagnostics failed - no valid samples");
    }

    println!();
}

Common Pitfalls

  1. Too few groups: Can't estimate group-level variation reliably
  2. Too few obs/group: Group-specific parameters poorly estimated
  3. Extreme imbalance: Some groups dominate inference
  4. Over-parameterization: More parameters than data can support
  5. Identification issues: Correlated effects with insufficient data

Advanced Extensions

Time-Varying Hierarchical Models

// Simplified time-varying hierarchical model
fn _time_varying_hierarchical(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    _time_data: Vec<f64>,
    group_ids: Vec<usize>,
    _n_groups: usize,
    _n_times: usize,
) -> Model<(f64, f64, f64, f64)> {
    prob! {
        // Population-level parameters
        let mu_alpha0 <- sample(addr!("mu_alpha0"), fugue::Normal::new(0.0, 5.0).unwrap());
        let beta <- sample(addr!("beta"), fugue::Normal::new(0.0, 2.0).unwrap());
        let sigma_alpha <- sample(addr!("sigma_alpha"), Gamma::new(1.0, 1.0).unwrap());
        let sigma_y <- sample(addr!("sigma_y"), Gamma::new(1.0, 1.0).unwrap());

        // Observations with time-varying group effects (simplified)
        let _observations <- plate!(i in 0..x_data.len() => {
            let group_j = group_ids[i];
            let x_i = x_data[i];
            let y_i = y_data[i];
            sample(addr!("alpha", group_j), fugue::Normal::new(mu_alpha0, sigma_alpha).unwrap())
                .bind(move |alpha_j| {
                    let mu_i = alpha_j + beta * x_i;
                    observe(addr!("y", i), fugue::Normal::new(mu_i, sigma_y).unwrap(), y_i)
                })
        });

        pure((beta, mu_alpha0, sigma_alpha, sigma_y))
    }
}

Nested Hierarchical Structures

For multi-level nesting (students within classes within schools):

// Simplified nested hierarchical structure
fn _nested_hierarchical(
    x_data: Vec<f64>,
    y_data: Vec<f64>,
    class_ids: Vec<usize>,
    _school_ids: Vec<usize>,
    _n_classes: usize,
    _n_schools: usize,
) -> Model<(f64, f64, f64, f64, f64)> {
    prob! {
        // Population level
        let mu <- sample(addr!("mu"), fugue::Normal::new(0.0, 5.0).unwrap());
        let beta <- sample(addr!("beta"), fugue::Normal::new(0.0, 2.0).unwrap());

        // School and class level variation (simplified)
        let sigma_class <- sample(addr!("sigma_class"), Gamma::new(1.0, 1.0).unwrap());
        let sigma_y <- sample(addr!("sigma_y"), Gamma::new(1.0, 1.0).unwrap());

        // Observations with nested class effects
        let _observations <- plate!(i in 0..x_data.len() => {
            let class_c = class_ids[i];
            let x_i = x_data[i];
            let y_i = y_data[i];
            sample(addr!("class", class_c), fugue::Normal::new(0.0, sigma_class).unwrap())
                .bind(move |class_effect| {
                    let mu_i = mu + class_effect + beta * x_i;
                    observe(addr!("y", i), fugue::Normal::new(mu_i, sigma_y).unwrap(), y_i)
                })
        });

        pure((mu, beta, sigma_class, sigma_y, sigma_y)) // school-level SD omitted in this simplified sketch; sigma_y repeated to fill the 5-tuple
    }
}

Advanced Hierarchical Features

Fugue's hierarchical modeling strengths:

  • Automatic constraint handling for variance parameters
  • Efficient MCMC with adaptive proposals for hierarchical correlation
  • Flexible prior specifications for complex hierarchical structures
  • Built-in diagnostics for hierarchical model assessment

Production Considerations

Model Deployment

// Prediction for hierarchical models with new groups
fn hierarchical_prediction() {
    println!("=== Hierarchical Model Prediction ===\n");

    let n_groups = 3;
    let n_per_group = 10;
    let true_effects = vec![-0.5, 0.2, 0.8];
    let true_beta = 0.5;

    let (x_data, y_data, group_ids) =
        generate_hierarchical_data(n_groups, n_per_group, &true_effects, true_beta, 0.4, 999);

    println!("๐ŸŽฏ Training hierarchical model for prediction...");
    let model_fn = move || {
        varying_intercepts_model(x_data.clone(), y_data.clone(), group_ids.clone(), n_groups)
    };
    let mut rng = StdRng::seed_from_u64(1010);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 50);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .take(50) // Use subset for prediction
        .collect();

    if !valid_samples.is_empty() {
        println!("โœ… Model trained with {} samples", valid_samples.len());

        // Tuple order from the model: (mu_alpha, sigma_alpha, beta, sigma_y, _)
        let mu_alpha_samples: Vec<f64> =
            valid_samples.iter().map(|(params, _)| params.0).collect();
        let beta_samples: Vec<f64> = valid_samples.iter().map(|(params, _)| params.2).collect();

        let mean_mu_alpha = mu_alpha_samples.iter().sum::<f64>() / mu_alpha_samples.len() as f64;
        let mean_beta = beta_samples.iter().sum::<f64>() / beta_samples.len() as f64;

        println!("\n๐Ÿ”ฎ Prediction for New Group:");
        println!(
            "   - New group starts with population mean: {:.2}",
            mean_mu_alpha
        );

        // Simulate prediction for new group with x=2.0
        let x_new = 2.0;
        let pred_mean = mean_mu_alpha + mean_beta * x_new;

        println!("   - For x={:.1}: ลท={:.2}", x_new, pred_mean);

        println!(
            "\n๐Ÿ’ก Hierarchical predictions balance group-specific and population information!"
        );
    } else {
        println!("โŒ Model training failed");
    }

    println!();
}

Monitoring and Updates

  1. New groups: How to handle previously unseen groups (see the sketch after this list)
  2. Growing groups: Re-estimation as group sizes increase
  3. Shrinkage monitoring: Ensure appropriate partial pooling behavior
  4. Prior sensitivity: Regular checks on hierarchical prior specification
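
For item 1, a previously unseen group can be predicted from the population level: use the posterior for mu_alpha and widen the predictive uncertainty by the between-group spread sigma_alpha. A plug-in sketch using posterior means (it ignores posterior parameter uncertainty; the names are illustrative):

// Hypothetical sketch: plug-in prediction for a brand-new group at input x_new.
// Point prediction uses the population mean intercept; the predictive SD widens
// by the between-group spread sigma_alpha.
fn predict_new_group(
    mu_alpha: f64,    // posterior mean of population intercept
    sigma_alpha: f64, // posterior mean of between-group SD
    beta: f64,        // posterior mean slope
    sigma_y: f64,     // posterior mean observation SD
    x_new: f64,
) -> (f64, f64) {
    let y_hat = mu_alpha + beta * x_new;
    let pred_sd = (sigma_alpha * sigma_alpha + sigma_y * sigma_y).sqrt();
    (y_hat, pred_sd)
}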

Hierarchical Models Mastery

You now have comprehensive understanding of:

  • ✅ Varying intercepts and slopes for group-level variation
  • ✅ Mixed effects models for complex hierarchical relationships
  • ✅ Hierarchical priors for robust multi-level inference
  • ✅ Model selection across hierarchical complexity levels
  • ✅ Practical implementation with computational diagnostics
  • ✅ Advanced extensions for complex real-world scenarios

Next steps: Apply these hierarchical modeling techniques to your grouped data analysis challenges!

API Reference

The complete API documentation for Fugue is hosted on docs.rs:

โ†’ View API Documentation on docs.rs

This includes:

  • Complete module documentation
  • Function and struct reference
  • Code examples and usage patterns
  • Inter-crate documentation links
  • Search functionality

Note: The API documentation is automatically generated from the source code and updated with each release.