
Classification

A comprehensive guide to Bayesian classification using Fugue. This tutorial demonstrates how to build, analyze, and extend classification models for discrete outcomes, showcasing the power of probabilistic programming for uncertainty quantification and principled model selection.

Learning Objectives

By the end of this tutorial, you will understand:

  • Binary Classification: Logistic regression with posterior uncertainty for two-class problems
  • Multi-class Classification: Multinomial logistic models and one-vs-rest approaches
  • Hierarchical Classification: Group-level effects for nested data structures
  • Model Comparison: Information criteria (DIC, WAIC) and Bayes factors for model selection
  • Uncertainty Quantification: Extracting and interpreting credible intervals for predictions
  • Robust Methods: Constraint-aware MCMC for stable parameter estimation
  • Production Applications: Scalable classification workflows for real-world deployment

The Classification Framework

Classification problems involve predicting discrete outcomes from continuous or discrete inputs. In the Bayesian framework, we treat classification parameters as random variables with prior distributions, enabling natural uncertainty quantification and robust model comparison.

graph TB
    A["Labeled Data: (x₁,y₁), (x₂,y₂), ..., (xₙ,yₙ)"] --> B["Classification Model<br/>P(y|x, θ)"]

    B --> C["Bayesian Framework"]
    C --> D["Prior: p(θ)"]
    C --> E["Likelihood: p(y|X, θ)"]

    D --> F["Posterior: p(θ|y, X)"]
    E --> F

    F --> G["MCMC Sampling"]
    G --> H["Parameter Uncertainty"]
    G --> I["Prediction Probabilities"]
    G --> J["Model Comparison"]

Advantages of Bayesian Classification

Many traditional machine learning classifiers return a single fitted model and one prediction per input. Bayesian classification provides:

  • Posterior probability distributions over class labels
  • Uncertainty estimates for each prediction
  • Principled model comparison using marginal likelihoods
  • Regularization through the prior: for example, a Normal(0, 2) prior on each coefficient shrinks estimates toward zero, much like an L2 penalty

Binary Classification: Logistic Regression

The foundation of Bayesian classification is logistic regression, which models the probability of binary outcomes.

Mathematical Model

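For binary classification, we model:

$$y_i \sim \text{Bernoulli}(p_i), \qquad \operatorname{logit}(p_i) = \log\frac{p_i}{1 - p_i} = \eta_i = \beta_0 + \sum_{j=1}^{p} \beta_j x_{ij}, \qquad \beta_j \sim \mathcal{N}(0, 2^2)$$

Where:

  • $y_i \in \{0, 1\}$ is the binary outcome
  • $p_i = P(y_i = 1 \mid x_i)$ is the probability of class 1
  • $\eta_i$ is the log-odds
  • $\beta_0, \beta_1, \dots, \beta_p$ are the regression coefficients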
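As a quick worked check of the model, take the true coefficients used to generate the demo's synthetic data (intercept $-1.0$, $\beta_1 = 2.0$, $\beta_2 = -1.5$) and the test point $x = (0.5, -0.8)$ used later for prediction:

$$\eta = -1.0 + 2.0 \times 0.5 + (-1.5) \times (-0.8) = 1.2, \qquad p = \frac{1}{1 + e^{-1.2}} \approx 0.769$$

The demo's posterior mean prediction for that point should land in this neighborhood, though it will not match exactly because the coefficients are estimated from 100 noisy observations.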

Implementation

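use fugue::*;

// Basic Bayesian logistic regression model.
// Each feature vector is expected to start with a constant 1.0, so beta[0] acts as the intercept.
fn logistic_regression_model(features: Vec<Vec<f64>>, labels: Vec<bool>) -> Model<Vec<f64>> {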
    let n_features = features[0].len();

    prob! {
        // Sample coefficients with regularizing priors - build using plate
        let coefficients <- plate!(i in 0..n_features => {
            sample(addr!("beta", i), fugue::Normal::new(0.0, 2.0).unwrap())
        });

        // Clone coefficients for use in closure
        let coefficients_for_obs = coefficients.clone();
        let _observations <- plate!(obs_idx in features.iter().zip(labels.iter()).enumerate() => {
            let (idx, (x_vec, &y)) = obs_idx;
            // Compute linear predictor (log-odds)
            let mut linear_pred = 0.0;
            for (coef, &x_val) in coefficients_for_obs.iter().zip(x_vec.iter()) {
                linear_pred += coef * x_val;
            }

            // Convert to probability using logistic function
            let prob = 1.0 / (1.0 + (-linear_pred).exp());

            // Keep p strictly inside (0, 1) so the Bernoulli log-density stays finite
            let bounded_prob = prob.clamp(1e-10, 1.0 - 1e-10);

            // Observe the binary outcome
            observe(addr!("y", idx), Bernoulli::new(bounded_prob).unwrap(), y)
        });

        pure(coefficients)
    }
}
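The clamp above keeps the Bernoulli valid, but it also caps each observation's log-likelihood at about $\ln(10^{-10}) \approx -23$, so a badly mislabeled point is penalized less than the model implies. When computing pointwise log-likelihoods by hand (for example, for information criteria), you can evaluate them directly on the log-odds scale. The following is a minimal std-only sketch. `softplus` and `bernoulli_logit_log_lik` are hypothetical helper names, not Fugue API.

```rust
/// Numerically stable log(1 + exp(x)), often called "softplus".
fn softplus(x: f64) -> f64 {
    if x > 0.0 {
        // log(1 + e^x) = x + log(1 + e^-x), avoids overflow for large x
        x + (-x).exp().ln_1p()
    } else {
        x.exp().ln_1p()
    }
}

/// Bernoulli log-likelihood of `y` given log-odds `eta`, without forming p explicitly:
/// log p = -softplus(-eta) and log(1 - p) = -softplus(eta).
fn bernoulli_logit_log_lik(y: bool, eta: f64) -> f64 {
    if y {
        -softplus(-eta)
    } else {
        -softplus(eta)
    }
}
```

For `eta = 800` and `y = false`, this returns -800, whereas the clamped version would report roughly -23.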

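// Uses rand's StdRng/SeedableRng, Fugue's adaptive_mcmc_chain, and a
// generate_classification_data helper from the example file (not shown here).
fn binary_classification_demo() {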
    println!("=== Binary Classification with Logistic Regression ===\n");

    // Generate synthetic data
    let (features, labels) = generate_classification_data(100, 42);
    let positive_cases = labels.iter().filter(|&&x| x).count();

    println!("📊 Generated {} data points", features.len());
    println!("   - Features: {} dimensions", features[0].len());
    println!(
        "   - Positive cases: {} / {} ({:.1}%)",
        positive_cases,
        labels.len(),
        100.0 * positive_cases as f64 / labels.len() as f64
    );
    println!("   - True coefficients: intercept=-1.0, β₁=2.0, β₂=-1.5");

    // Run MCMC inference
    let model_fn = move || logistic_regression_model(features.clone(), labels.clone());
    let mut rng = StdRng::seed_from_u64(12345);

    println!("\n🔬 Running MCMC inference...");
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 800, 200);

    // Extract coefficient estimates
    let valid_samples: Vec<_> = samples
        .iter()
        .filter_map(|(coeffs, trace)| {
            if trace.total_log_weight().is_finite() {
                Some(coeffs)
            } else {
                None
            }
        })
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "✅ MCMC completed with {} valid samples",
            valid_samples.len()
        );
        println!("\n📈 Coefficient Estimates:");

        let coef_names = ["Intercept", "β₁ (feature 1)", "β₂ (feature 2)"];
        let true_coefs = [-1.0, 2.0, -1.5];

        for (i, (name, true_val)) in coef_names.iter().zip(true_coefs.iter()).enumerate() {
            let coef_samples: Vec<f64> = valid_samples.iter().map(|coeffs| coeffs[i]).collect();

            let mean_coef = coef_samples.iter().sum::<f64>() / coef_samples.len() as f64;
            let std_coef = {
                let variance = coef_samples
                    .iter()
                    .map(|c| (c - mean_coef).powi(2))
                    .sum::<f64>()
                    / (coef_samples.len() - 1) as f64;
                variance.sqrt()
            };

            println!(
                "   - {}: {:.3} ± {:.3} (true: {:.1})",
                name, mean_coef, std_coef, true_val
            );
        }

        // Model diagnostics
        let avg_log_weight = samples
            .iter()
            .map(|(_, trace)| trace.total_log_weight())
            .filter(|w| w.is_finite())
            .sum::<f64>()
            / valid_samples.len() as f64;

        println!("   - Average log joint (log prior + log likelihood): {:.2}", avg_log_weight);

        // Make predictions on new data
        println!("\n🔮 Prediction Example:");
        let test_features = [1.0, 0.5, -0.8]; // New observation (leading 1.0 is the intercept term)
        let mut predicted_probs = Vec::new();

        // Use a subset of 50 posterior draws for speed
        for coeffs in valid_samples.iter().take(50) {
            let mut linear_pred = 0.0;
            for (coef, &x_val) in coeffs.iter().zip(test_features.iter()) {
                linear_pred += coef * x_val;
            }
            let prob = 1.0 / (1.0 + (-linear_pred).exp());
            predicted_probs.push(prob);
        }

        let mean_prob = predicted_probs.iter().sum::<f64>() / predicted_probs.len() as f64;
        let std_prob = {
            let variance = predicted_probs
                .iter()
                .map(|p| (p - mean_prob).powi(2))
                .sum::<f64>()
                / (predicted_probs.len() - 1) as f64;
            variance.sqrt()
        };

        println!(
            "   - Test point [0.5, -0.8]: P(y=1) = {:.3} ± {:.3}",
            mean_prob, std_prob
        );
        if mean_prob > 0.5 {
            println!("   - Prediction: Class 1 (probability > 0.5)");
        } else {
            println!("   - Prediction: Class 0 (probability ≤ 0.5)");
        }
    } else {
        println!("❌ No valid MCMC samples obtained");
    }

    println!();
}
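The demo summarizes $P(y = 1)$ by its mean ± standard deviation. Predicted probabilities are bounded and often skewed, so a percentile (central credible) interval is usually more informative. The following is a minimal std-only sketch. `quantile` and `predictive_interval` are hypothetical helper names, and it assumes each draw includes the intercept as its first coefficient, as in the demo.

```rust
/// Empirical quantile with linear interpolation between order statistics.
fn quantile(values: &[f64], q: f64) -> f64 {
    assert!(!values.is_empty() && (0.0..=1.0).contains(&q));
    let mut sorted = values.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).expect("NaN in samples"));
    let pos = q * (sorted.len() - 1) as f64;
    let lo = pos.floor() as usize;
    let hi = pos.ceil() as usize;
    let frac = pos - lo as f64;
    sorted[lo] + frac * (sorted[hi] - sorted[lo])
}

/// Posterior predictive mean of P(y = 1 | x) and a central credible interval at `level`.
/// Returns (mean, lower, upper).
fn predictive_interval(draws: &[Vec<f64>], x: &[f64], level: f64) -> (f64, f64, f64) {
    // One predicted probability per posterior draw of the coefficients
    let probs: Vec<f64> = draws
        .iter()
        .map(|coeffs| {
            let eta: f64 = coeffs.iter().zip(x).map(|(c, xi)| c * xi).sum();
            1.0 / (1.0 + (-eta).exp())
        })
        .collect();
    let mean = probs.iter().sum::<f64>() / probs.len() as f64;
    let tail = (1.0 - level) / 2.0;
    (mean, quantile(&probs, tail), quantile(&probs, 1.0 - tail))
}
```

For the demo's test point, you would call `predictive_interval(&draws, &[1.0, 0.5, -0.8], 0.9)`, where `draws` holds owned copies of the coefficient vectors in `valid_samples`.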

Key Features

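  • Constraint handling: the logistic function maps any real-valued linear predictor into (0, 1), and the clamp keeps the Bernoulli log-density finite
  • Interpretable coefficients: each coefficient is the change in log-odds (a log odds ratio) per unit change in its feature
  • Natural uncertainty: posterior samples give credible intervals for parameters and predictions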

Logistic Regression Interpretation

  • Coefficient $\beta_j > 0$: feature $x_j$ increases the log-odds of class 1
  • Coefficient $\beta_j < 0$: feature $x_j$ decreases the log-odds of class 1
  • $\exp(\beta_j)$ gives the odds ratio for a unit change in $x_j$
  • Use standardized features for coefficient comparability
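To make the odds-ratio reading concrete, the synthetic data's true $\beta_1 = 2.0$ multiplies the odds of class 1 by $e^{2} \approx 7.39$ per unit of feature 1. The true $\beta_2 = -1.5$ multiplies them by $e^{-1.5} \approx 0.223$, a drop of about 78%. The following std-only sketch uses hypothetical function names. It transforms each posterior draw before summarizing, because $\exp$ of the posterior mean is not the posterior mean of $\exp(\beta)$.

```rust
/// Odds ratio for a one-unit increase in a feature, holding the others fixed: exp(beta).
fn odds_ratio(beta: f64) -> f64 {
    beta.exp()
}

/// Percentage change in the odds of class 1 per one-unit increase in the feature.
fn pct_change_in_odds(beta: f64) -> f64 {
    100.0 * (beta.exp() - 1.0)
}

/// Posterior draws of the odds ratio: apply exp to every coefficient draw,
/// then summarize (mean, quantiles) on this transformed scale.
fn odds_ratio_draws(beta_draws: &[f64]) -> Vec<f64> {
    beta_draws.iter().map(|&b| odds_ratio(b)).collect()
}
```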

Multi-class Classification: Multinomial Logit

For problems with more than two classes, we use multinomial logistic regression.

Mathematical Model

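For $K$ classes, we model:

$$P(y_i = k \mid x_i) = \frac{\exp(x_i^\top \beta_k)}{1 + \sum_{j=1}^{K-1} \exp(x_i^\top \beta_j)}, \quad k = 1, \dots, K-1, \qquad P(y_i = K \mid x_i) = \frac{1}{1 + \sum_{j=1}^{K-1} \exp(x_i^\top \beta_j)}$$

The last class ($k = K$) serves as the reference category, so $\log \frac{P(y_i = k \mid x_i)}{P(y_i = K \mid x_i)} = x_i^\top \beta_k$.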
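The reference-category softmax can be computed directly. The following is a minimal std-only sketch for one set of coefficient draws. `multinomial_probs` is a hypothetical name, `betas[k]` holds the coefficients of class `k` for the first K-1 classes, and the reference class has its linear predictor fixed at 0.

```rust
/// Class probabilities for a multinomial logit with the last class as reference.
/// `betas` has K-1 coefficient vectors; the returned vector has K probabilities.
fn multinomial_probs(betas: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    // Linear predictors for the non-reference classes
    let mut etas: Vec<f64> = betas
        .iter()
        .map(|b| b.iter().zip(x).map(|(bj, xj)| bj * xj).sum::<f64>())
        .collect();
    // Reference class: eta pinned at 0 for identifiability
    etas.push(0.0);

    // Subtract the max before exponentiating (log-sum-exp trick) to avoid overflow
    let max_eta = etas.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = etas.iter().map(|e| (e - max_eta).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}
```

Shifting every linear predictor by the same constant leaves the ratios unchanged, so subtracting the maximum gives the same probabilities as the formula above while staying finite for large predictors.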

Implementation

// Multinomial logistic regression for multi-class classification
// Note: This is a simplified version - full multinomial requires more complex implementation
fn multiclass_classification_demo() {
    println!("=== Multi-class Classification (Conceptual) ===\n");

    let (features, labels) = generate_multiclass_data(150, 3, 1337);

    println!("📊 Generated {} data points", features.len());
    println!("   - {} classes", 3);
    println!("   - Features: {} dimensions", features[0].len());

    // Count class distribution
    let mut class_counts = [0; 3];
    for &label in &labels {
        class_counts[label] += 1;
    }

    for (class_id, count) in class_counts.iter().enumerate() {
        println!(
            "   - Class {}: {} samples ({:.1}%)",
            class_id,
            count,
            100.0 * *count as f64 / labels.len() as f64
        );
    }

    println!("\n💡 Multinomial Classification Concepts:");
    println!("   - Uses K-1 sets of coefficients (reference category approach)");
    println!("   - Each coefficient set models log(P(class_k) / P(class_reference))");
    println!("   - Probabilities sum to 1 via softmax transformation");
    println!("   - More complex to implement but follows same Bayesian principles");

    // For now, demonstrate the concept with binary classification on each class
    println!("\n🔬 One-vs-Rest Classification (simplified approach):");

    for target_class in 0..3 {
        // Convert to binary problem: target_class vs. all others
        let binary_labels: Vec<bool> = labels.iter().map(|&label| label == target_class).collect();

        let positive_cases = binary_labels.iter().filter(|&&x| x).count();

        println!("\n   Class {} vs Rest:", target_class);
        println!(
            "   - Positive cases: {} / {}",
            positive_cases,
            binary_labels.len()
        );

        // Clone data for each iteration to avoid move issues
        let features_copy = features.clone();
        let model_fn =
            move || logistic_regression_model(features_copy.clone(), binary_labels.clone());
        let mut rng = StdRng::seed_from_u64(1000 + target_class as u64);

        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 60);
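        // Count only samples whose trace has a finite log-weight
        let valid_samples = samples
            .iter()
            .filter(|(_, trace)| trace.total_log_weight().is_finite())
            .count();

        if valid_samples > 0 {
            println!("   - MCMC: {} valid samples obtained", valid_samples);
        }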
    }

    println!("\n💭 Note: Full multinomial logistic regression requires implementing");
    println!("   the softmax link function and careful handling of identifiability constraints.");
    println!();
}
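The demo fits the three one-vs-rest models but stops at counting samples. One common heuristic for turning their outputs into a single prediction is to take each model's posterior mean $P(\text{class } k \text{ vs rest} \mid x)$, renormalize the values, and pick the largest. The models are fitted independently, so their raw outputs need not sum to 1. The following is a minimal std-only sketch with hypothetical function names.

```rust
/// Combine per-class one-vs-rest probabilities into one distribution by renormalizing.
fn combine_one_vs_rest(binary_probs: &[f64]) -> Vec<f64> {
    let total: f64 = binary_probs.iter().sum();
    if total <= 0.0 {
        // Degenerate case: fall back to a uniform distribution
        return vec![1.0 / binary_probs.len() as f64; binary_probs.len()];
    }
    binary_probs.iter().map(|p| p / total).collect()
}

/// Index of the most probable class.
fn predict_class(probs: &[f64]) -> usize {
    probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).expect("probabilities are never NaN"))
        .map(|(k, _)| k)
        .expect("at least one class")
}
```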

Hierarchical Classification

When your data has group structure (e.g., students within schools, patients within hospitals), hierarchical models can improve predictions by sharing information across groups.

Mathematical Model

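$$y_i \sim \text{Bernoulli}(p_i), \qquad \operatorname{logit}(p_i) = \alpha_{g[i]} + \beta x_i, \qquad \alpha_g \sim \mathcal{N}(\mu_\alpha, \sigma_\alpha^2)$$

$$\mu_\alpha \sim \mathcal{N}(0, 2^2), \qquad \beta \sim \mathcal{N}(0, 2^2), \qquad \sigma_\alpha \sim \text{Gamma}(1, 1)$$

Where:

  • $i$ indexes individuals, $g$ indexes groups, and $g[i]$ is the group of individual $i$
  • $\alpha_g$ are group-specific intercepts
  • $\mu_\alpha$ and $\sigma_\alpha$ are hyperparameters; $\sigma_\alpha$ controls how much groups can vary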

Implementation

// Hierarchical logistic regression with group-level effects
fn hierarchical_classification_model(
    features: Vec<Vec<f64>>,
    labels: Vec<bool>,
    groups: Vec<usize>,
) -> Model<(f64, f64, Vec<f64>)> {
    let n_groups = groups.iter().max().unwrap_or(&0) + 1;

    prob! {
        // Global parameters
        let global_intercept <- sample(addr!("global_intercept"), fugue::Normal::new(0.0, 2.0).unwrap());
        let slope <- sample(addr!("slope"), fugue::Normal::new(0.0, 2.0).unwrap());

        // Group-level variance
        let group_sigma <- sample(addr!("group_sigma"), Gamma::new(1.0, 1.0).unwrap());

        // Group-specific intercepts using plate notation
        let group_intercepts <- plate!(g in 0..n_groups => {
            sample(addr!("group_intercept", g), fugue::Normal::new(global_intercept, group_sigma).unwrap())
        });

        // Clone group_intercepts for use in closure
        let group_intercepts_for_obs = group_intercepts.clone();
        let _observations <- plate!(data in features.iter()
            .map(|f| f[1]) // Extract the single feature (after intercept)
            .zip(labels.iter())
            .zip(groups.iter())
            .enumerate() => {
            let (obs_idx, ((x_val, &y), &group_id)) = data;
            let linear_pred = group_intercepts_for_obs[group_id] + slope * x_val;
            let prob = 1.0 / (1.0 + (-linear_pred).exp());
            let bounded_prob = prob.clamp(1e-10, 1.0 - 1e-10);

            observe(addr!("obs", obs_idx), Bernoulli::new(bounded_prob).unwrap(), y)
        });

        pure((global_intercept, slope, group_intercepts))
    }
}

fn hierarchical_classification_demo() {
    println!("=== Hierarchical Classification ===\n");

    let (features, labels, groups) = generate_hierarchical_data(4, 25, 5678);
    let n_groups = groups.iter().max().unwrap() + 1;

    println!("📊 Generated hierarchical data:");
    println!(
        "   - {} groups with {} observations each",
        n_groups,
        features.len() / n_groups
    );
    println!("   - Total: {} data points", features.len());

    // Show group-wise statistics
    for group_id in 0..n_groups {
        let group_labels: Vec<bool> = groups
            .iter()
            .zip(labels.iter())
            .filter_map(|(&g, &y)| if g == group_id { Some(y) } else { None })
            .collect();

        let positive_rate =
            group_labels.iter().filter(|&&x| x).count() as f64 / group_labels.len() as f64;
        println!(
            "   - Group {}: {:.1}% positive cases",
            group_id,
            positive_rate * 100.0
        );
    }

    println!("\n🔬 Running hierarchical MCMC...");
    let model_fn =
        move || hierarchical_classification_model(features.clone(), labels.clone(), groups.clone());
    let mut rng = StdRng::seed_from_u64(9999);
    let samples = adaptive_mcmc_chain(&mut rng, model_fn, 600, 150);

    let valid_samples: Vec<_> = samples
        .iter()
        .filter(|(_, trace)| trace.total_log_weight().is_finite())
        .collect();

    if !valid_samples.is_empty() {
        println!(
            "✅ Hierarchical MCMC completed with {} valid samples",
            valid_samples.len()
        );

        // Extract global parameters
        let global_intercepts: Vec<f64> =
            valid_samples.iter().map(|(params, _)| params.0).collect();
        let slopes: Vec<f64> = valid_samples.iter().map(|(params, _)| params.1).collect();

        let mean_global_int =
            global_intercepts.iter().sum::<f64>() / global_intercepts.len() as f64;
        let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;

        println!("\n📈 Global Parameter Estimates:");
        println!("   - Global intercept: {:.3} (true: ~0.0)", mean_global_int);
        println!("   - Slope: {:.3} (true: 1.5)", mean_slope);

        // Extract group-specific intercepts
        println!("\n🏘️  Group-Specific Intercepts:");
        for group_id in 0..n_groups {
            let group_intercepts: Vec<f64> = valid_samples
                .iter()
                .map(|(params, _)| params.2[group_id])
                .collect();

            let mean_group_int =
                group_intercepts.iter().sum::<f64>() / group_intercepts.len() as f64;
            println!("   - Group {}: {:.3}", group_id, mean_group_int);
        }

        println!("\n💡 Hierarchical Benefits:");
        println!("   - Groups share information through global parameters");
        println!("   - Individual groups can have their own intercepts");
        println!("   - Better predictions for groups with less data");
        println!("   - Automatic regularization through group-level priors");
    } else {
        println!("❌ No valid hierarchical samples obtained");
    }

    println!();
}
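For intuition about why groups with less data benefit most, here is a rough normal-approximation sketch of partial pooling on the log-odds scale. This is not what the MCMC computes. It treats a group's empirical log-odds as a noisy estimate with variance of about $1/(n \hat{p}(1-\hat{p}))$ and combines it with the group-level prior $\mathcal{N}(\mu_\alpha, \sigma_\alpha^2)$ by precision weighting. `shrunken_log_odds` is a hypothetical name.

```rust
/// Approximate partially pooled log-odds for one group (normal approximation).
fn shrunken_log_odds(successes: usize, n: usize, mu: f64, sigma: f64) -> f64 {
    // Empirical-logit correction: add 0.5 to each cell so 0 or n successes stay finite
    let p_hat = (successes as f64 + 0.5) / (n as f64 + 1.0);
    let raw_log_odds = (p_hat / (1.0 - p_hat)).ln();

    // Precision (inverse variance) of the data estimate and of the group-level prior
    let data_precision = n as f64 * p_hat * (1.0 - p_hat);
    let prior_precision = 1.0 / (sigma * sigma);

    // Precision-weighted average: small groups are pulled strongly toward mu
    (data_precision * raw_log_odds + prior_precision * mu) / (data_precision + prior_precision)
}
```

With `mu = 0` and `sigma = 1`, a group with 8 positives out of 10 moves from a raw log-odds of about 1.22 to about 0.78. A group with 800 out of 1000 barely moves.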

Model Comparison and Selection

Bayesian methods provide principled approaches to comparing models:

Deviance Information Criterion (DIC)

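DIC balances model fit against complexity:

$$\text{DIC} = \bar{D} + p_D, \qquad D(\theta) = -2 \log p(y \mid \theta), \qquad p_D = \bar{D} - D(\bar{\theta})$$

Where $\bar{D}$ is the posterior mean deviance, $\bar{\theta}$ is the posterior mean of the parameters, and $p_D$ is the effective number of parameters. Lower DIC is better.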
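DIC for logistic regression can be computed from coefficient draws alone. The following is a minimal std-only sketch. `softplus`, `logistic_deviance`, and `dic_logistic` are hypothetical names, and feature vectors are assumed to start with the intercept column, as in the binary model. Because the logistic deviance is convex in the coefficients, $p_D$ computed this way is non-negative.

```rust
/// Numerically stable log(1 + exp(x)).
fn softplus(x: f64) -> f64 {
    if x > 0.0 { x + (-x).exp().ln_1p() } else { x.exp().ln_1p() }
}

/// Deviance D(beta) = -2 * sum_i log p(y_i | beta) for logistic regression.
fn logistic_deviance(coeffs: &[f64], features: &[Vec<f64>], labels: &[bool]) -> f64 {
    let log_lik: f64 = features
        .iter()
        .zip(labels)
        .map(|(x, &y)| {
            let eta: f64 = coeffs.iter().zip(x).map(|(c, xi)| c * xi).sum();
            // log p = -softplus(-eta), log(1 - p) = -softplus(eta)
            if y { -softplus(-eta) } else { -softplus(eta) }
        })
        .sum();
    -2.0 * log_lik
}

/// DIC = D_bar + p_D with p_D = D_bar - D(theta_bar). Returns (dic, p_d).
fn dic_logistic(draws: &[Vec<f64>], features: &[Vec<f64>], labels: &[bool]) -> (f64, f64) {
    let n_draws = draws.len() as f64;
    // Posterior mean deviance D_bar
    let d_bar = draws
        .iter()
        .map(|b| logistic_deviance(b, features, labels))
        .sum::<f64>()
        / n_draws;
    // Deviance at the posterior mean of the coefficients, D(theta_bar)
    let n_coef = draws[0].len();
    let theta_bar: Vec<f64> = (0..n_coef)
        .map(|j| draws.iter().map(|b| b[j]).sum::<f64>() / n_draws)
        .collect();
    let p_d = d_bar - logistic_deviance(&theta_bar, features, labels);
    (d_bar + p_d, p_d)
}
```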

Widely Applicable Information Criterion (WAIC)

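WAIC is a more robust alternative that uses the whole posterior rather than a point estimate, working from pointwise log-likelihoods over $S$ posterior draws:

$$\text{WAIC} = -2\left(\text{lppd} - p_{\text{WAIC}}\right), \qquad \text{lppd} = \sum_{i=1}^{n} \log\left(\frac{1}{S}\sum_{s=1}^{S} p(y_i \mid \theta^{(s)})\right), \qquad p_{\text{WAIC}} = \sum_{i=1}^{n} \operatorname{Var}_{s}\left[\log p(y_i \mid \theta^{(s)})\right]$$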
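WAIC needs the log-likelihood of each observation under each draw, not just a trace's total log-weight. In practice, you would re-evaluate each observation's Bernoulli log-density for every coefficient draw. Given that matrix, the following minimal std-only sketch computes WAIC. `waic` is a hypothetical name, and `log_lik[s][i]` is assumed to hold $\log p(y_i \mid \theta^{(s)})$.

```rust
/// WAIC from pointwise log-likelihoods. Returns (waic, lppd, p_waic).
fn waic(log_lik: &[Vec<f64>]) -> (f64, f64, f64) {
    let n_draws = log_lik.len();
    assert!(n_draws >= 2, "need at least two posterior draws");
    let n_obs = log_lik[0].len();

    let mut lppd = 0.0;
    let mut p_waic = 0.0;
    for i in 0..n_obs {
        let column: Vec<f64> = log_lik.iter().map(|row| row[i]).collect();

        // log( (1/S) * sum_s exp(l_s) ), computed with the log-sum-exp trick
        let max = column.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let sum_exp: f64 = column.iter().map(|l| (l - max).exp()).sum();
        lppd += max + sum_exp.ln() - (n_draws as f64).ln();

        // Sample variance of the pointwise log-likelihood across draws
        let mean = column.iter().sum::<f64>() / n_draws as f64;
        let var = column.iter().map(|l| (l - mean).powi(2)).sum::<f64>() / (n_draws as f64 - 1.0);
        p_waic += var;
    }

    (-2.0 * (lppd - p_waic), lppd, p_waic)
}
```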

Implementation

// Simple model comparison using the average trace log-weight (log prior + log likelihood)
fn model_comparison_demo() {
    println!("=== Model Comparison ===\n");

    let (features, labels) = generate_classification_data(80, 2021);

    println!("📊 Comparing different logistic regression models:");
    println!("   - Model 1: Intercept only");
    println!("   - Model 2: Intercept + Feature 1");
    println!("   - Model 3: Full model (Intercept + Feature 1 + Feature 2)");

    struct ModelResult {
        name: String,
        n_params: usize,
        log_likelihood: f64,
        samples: usize,
    }

    let mut results = Vec::new();

    // Model 1: Intercept only
    {
        let intercept_features: Vec<Vec<f64>> = features
            .iter()
            .map(|f| vec![f[0]]) // Just intercept
            .collect();
        let labels_clone = labels.clone();

        let model_fn =
            move || logistic_regression_model(intercept_features.clone(), labels_clone.clone());
        let mut rng = StdRng::seed_from_u64(1111);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 80);

        let valid_samples: Vec<_> = samples
            .iter()
            .filter(|(_, trace)| trace.total_log_weight().is_finite())
            .collect();

        if !valid_samples.is_empty() {
            let avg_log_lik = valid_samples
                .iter()
                .map(|(_, trace)| trace.total_log_weight())
                .sum::<f64>()
                / valid_samples.len() as f64;

            results.push(ModelResult {
                name: "Intercept only".to_string(),
                n_params: 1,
                log_likelihood: avg_log_lik,
                samples: valid_samples.len(),
            });
        }
    }

    // Model 2: Intercept + Feature 1
    {
        let reduced_features: Vec<Vec<f64>> = features
            .iter()
            .map(|f| vec![f[0], f[1]]) // Intercept + first feature
            .collect();
        let labels_clone = labels.clone();

        let model_fn =
            move || logistic_regression_model(reduced_features.clone(), labels_clone.clone());
        let mut rng = StdRng::seed_from_u64(2222);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 80);

        let valid_samples: Vec<_> = samples
            .iter()
            .filter(|(_, trace)| trace.total_log_weight().is_finite())
            .collect();

        if !valid_samples.is_empty() {
            let avg_log_lik = valid_samples
                .iter()
                .map(|(_, trace)| trace.total_log_weight())
                .sum::<f64>()
                / valid_samples.len() as f64;

            results.push(ModelResult {
                name: "Intercept + Feature 1".to_string(),
                n_params: 2,
                log_likelihood: avg_log_lik,
                samples: valid_samples.len(),
            });
        }
    }

    // Model 3: Full model
    {
        let labels_clone = labels.clone();
        let model_fn = move || logistic_regression_model(features.clone(), labels_clone.clone());
        let mut rng = StdRng::seed_from_u64(3333);
        let samples = adaptive_mcmc_chain(&mut rng, model_fn, 300, 80);

        let valid_samples: Vec<_> = samples
            .iter()
            .filter(|(_, trace)| trace.total_log_weight().is_finite())
            .collect();

        if !valid_samples.is_empty() {
            let avg_log_lik = valid_samples
                .iter()
                .map(|(_, trace)| trace.total_log_weight())
                .sum::<f64>()
                / valid_samples.len() as f64;

            results.push(ModelResult {
                name: "Full model".to_string(),
                n_params: 3,
                log_likelihood: avg_log_lik,
                samples: valid_samples.len(),
            });
        }
    }

    if !results.is_empty() {
        println!("\n🏆 Model Comparison Results:");
        println!("   Model                    | Params | Log-Likelihood | Samples");
        println!("   -------------------------|--------|----------------|--------");

        for result in &results {
            println!(
                "   {:24} | {:6} | {:14.2} | {:7}",
                result.name, result.n_params, result.log_likelihood, result.samples
            );
        }

        // Find best model
        if let Some(best) = results
            .iter()
            .max_by(|a, b| a.log_likelihood.partial_cmp(&b.log_likelihood).unwrap())
        {
            println!("\n🥇 Best Model: {} (highest log-likelihood)", best.name);
        }

        println!("\n💡 Model Selection Notes:");
        println!("   - The log-likelihood column is the average trace log-weight (it includes the log prior)");
        println!("   - Higher values indicate better in-sample fit");
        println!("   - In practice, use information criteria (AIC, BIC, WAIC)");
        println!("   - These account for model complexity to prevent overfitting");
        println!("   - Cross-validation provides robust model comparison");
    } else {
        println!("❌ Model comparison failed - no valid samples obtained");
    }

    println!();
}

Practical Considerations

Feature Engineering

Effective classification often requires thoughtful feature engineering:

// Example raw inputs
let x1 = 0.5;
let x2 = 0.8;
let category = "A";

// Polynomial features
let x1_squared = x1 * x1;
let x1_x2_interaction = x1 * x2;

// Categorical encoding (one-hot)
let is_category_a = if category == "A" { 1.0 } else { 0.0 };

Handling Class Imbalance

For imbalanced datasets, consider:

  • Weighted priors: Give more weight to rare classes
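  • Threshold tuning: Choose the decision threshold on P(y = 1 | x) from the relative costs of false positives and false negatives instead of defaulting to 0.5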
  • Stratified sampling: Ensure balanced training data
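For threshold tuning, standard decision theory gives a closed form. If a false positive costs $c_{FP}$ and a false negative costs $c_{FN}$, predicting class 1 has lower expected cost exactly when $p > c_{FP} / (c_{FP} + c_{FN})$. The following minimal sketch uses hypothetical function names. `p` should be the posterior predictive probability, averaged over draws as in the binary demo.

```rust
/// Bayes-optimal threshold on P(y = 1 | x) given misclassification costs.
/// Predicting 1 has expected cost (1 - p) * cost_fp; predicting 0 has expected cost p * cost_fn.
fn decision_threshold(cost_fp: f64, cost_fn: f64) -> f64 {
    cost_fp / (cost_fp + cost_fn)
}

/// Classify with the cost-based threshold instead of a fixed 0.5.
fn classify(p: f64, cost_fp: f64, cost_fn: f64) -> bool {
    p > decision_threshold(cost_fp, cost_fn)
}
```

When missing a rare positive costs nine times as much as a false alarm, the threshold drops to 0.1.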

Computational Considerations

  • Start simple: Begin with basic logistic regression
  • Check convergence: Monitor R-hat and effective sample size
  • Scale features: Standardize continuous predictors
  • Use constraints: Let Fugue's constraint-aware MCMC handle bounded parameters
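Standardizing also makes a Normal(0, 2) coefficient prior mean the same thing for every feature. The following is a minimal std-only sketch. `standardize` is a hypothetical name, and it assumes column 0 is the constant intercept, which it leaves unchanged. It returns the per-column statistics so that new observations can be scaled identically at prediction time.

```rust
/// Standardize feature columns to mean 0 and sd 1 (column 0 = intercept, left as-is).
/// Returns the scaled rows and the per-column (mean, sd) pairs used.
fn standardize(features: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<(f64, f64)>) {
    let n = features.len() as f64;
    let n_cols = features[0].len();
    // (mean, sd) = (0, 1) is the identity transform, used for the intercept column
    let mut stats = vec![(0.0, 1.0); n_cols];
    for j in 1..n_cols {
        let mean = features.iter().map(|row| row[j]).sum::<f64>() / n;
        let var = features.iter().map(|row| (row[j] - mean).powi(2)).sum::<f64>() / (n - 1.0);
        // Guard against constant columns
        let sd = if var > 0.0 { var.sqrt() } else { 1.0 };
        stats[j] = (mean, sd);
    }
    let scaled: Vec<Vec<f64>> = features
        .iter()
        .map(|row| row.iter().zip(&stats).map(|(x, &(m, s))| (x - m) / s).collect())
        .collect();
    (scaled, stats)
}
```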

MCMC for Classification

Classification models can be challenging for MCMC due to:

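  • Separation: When a feature (or combination of features) perfectly splits the classes, the likelihood keeps increasing as coefficients grow, so maximum-likelihood estimates diverge; a proper prior keeps the posterior well-defined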
  • Weak identification: Sparse data in some classes affects convergence
  • Constraint handling: Probabilities must sum to 1 in multinomial models

Use regularizing priors and check diagnostics carefully.
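Separation is easy to see numerically. On a perfectly separated toy dataset, the log-likelihood of a one-feature logistic model keeps rising as the slope grows. Adding a Normal(0, 2) prior gives the log posterior a finite maximum. The following std-only sketch uses hypothetical function names and an intercept-free model for simplicity.

```rust
/// Numerically stable log(1 + exp(x)).
fn softplus(x: f64) -> f64 {
    if x > 0.0 { x + (-x).exp().ln_1p() } else { x.exp().ln_1p() }
}

/// Log-likelihood of a one-feature logistic model (no intercept) with slope `beta`.
fn log_lik(beta: f64, xs: &[f64], ys: &[bool]) -> f64 {
    xs.iter()
        .zip(ys)
        .map(|(&x, &y)| {
            let eta = beta * x;
            if y { -softplus(-eta) } else { -softplus(eta) }
        })
        .sum()
}

/// Log density of a Normal(0, sigma) prior on beta, up to an additive constant.
fn log_prior(beta: f64, sigma: f64) -> f64 {
    -0.5 * (beta / sigma).powi(2)
}

/// Unnormalized log posterior: likelihood plus prior.
fn log_posterior(beta: f64, xs: &[f64], ys: &[bool], sigma: f64) -> f64 {
    log_lik(beta, xs, ys) + log_prior(beta, sigma)
}
```

With `xs = [-2, -1, 1, 2]` and `ys = [false, false, true, true]`, `log_lik` increases without bound toward 0 as `beta` grows. `log_posterior` peaks at a moderate slope, between 1.5 and 2 for this data.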

Performance Evaluation

Metrics for Binary Classification

// Example confusion-matrix counts
let tp = 10.0;
let tn = 20.0;
let fp = 5.0;
let fn_count = 3.0;

// Accuracy, Precision, Recall, F1-score
let accuracy = (tp + tn) / (tp + tn + fp + fn_count);
let precision = tp / (tp + fp);
let recall = tp / (tp + fn_count);
let f1 = 2.0 * precision * recall / (precision + recall);

Bayesian Evaluation

Unlike traditional ML, Bayesian methods naturally provide:

  • Credible intervals for all metrics
  • Prediction intervals for new observations
  • Model uncertainty via posterior model probabilities
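For example, instead of a single accuracy number, you can compute the accuracy implied by every posterior draw and report an interval. The following is a minimal std-only sketch. `accuracy_interval` is a hypothetical name, it uses nearest-rank quantiles, and it assumes each draw and each feature vector include the intercept in position 0.

```rust
/// Accuracy of the 0.5-threshold rule under each posterior draw, summarized by a
/// central credible interval at `level`. Returns (lower, upper).
fn accuracy_interval(
    draws: &[Vec<f64>],
    features: &[Vec<f64>],
    labels: &[bool],
    level: f64,
) -> (f64, f64) {
    let mut accs: Vec<f64> = draws
        .iter()
        .map(|coeffs| {
            let correct = features
                .iter()
                .zip(labels)
                .filter(|&(x, &y)| {
                    let eta: f64 = coeffs.iter().zip(x).map(|(c, xi)| c * xi).sum();
                    // eta > 0 is equivalent to P(y = 1 | x) > 0.5
                    (eta > 0.0) == y
                })
                .count();
            correct as f64 / labels.len() as f64
        })
        .collect();
    accs.sort_by(|a, b| a.partial_cmp(b).expect("accuracy is never NaN"));

    // Nearest-rank quantiles of the sorted accuracies
    let tail = (1.0 - level) / 2.0;
    let rank = |q: f64| (q * (accs.len() - 1) as f64).round() as usize;
    (accs[rank(tail)], accs[rank(1.0 - tail)])
}
```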

Advanced Extensions

Ordinal Classification

For ordered categorical outcomes (e.g., ratings, severity levels):

use fugue::*;

// Ordinal logistic regression with proportional odds.
// `features` should NOT include a constant intercept column: the cutpoints play that role.
fn ordinal_classification_model(
    features: Vec<Vec<f64>>,
    outcomes: Vec<usize>, // 0, 1, 2, ..., K-1
    n_categories: usize,
) -> Model<(Vec<f64>, Vec<f64>)> {
    assert!(n_categories >= 2, "ordinal model needs at least two categories");
    let n_features = features[0].len();
    let n_cuts = n_categories - 1;

    prob! {
        // Regression coefficients (shared across categories)
        let coefficients <- plate!(i in 0..n_features => {
            sample(addr!("beta", i), fugue::Normal::new(0.0, 2.0).unwrap())
        });

        // Cutpoints must be ordered: sample the first freely, then positive gaps
        let first_cut <- sample(addr!("cutpoint", 0), fugue::Normal::new(0.0, 5.0).unwrap());
        let deltas <- plate!(k in 1..n_cuts => {
            sample(addr!("delta", k), Gamma::new(1.0, 1.0).unwrap())
        });
        let cutpoints = std::iter::once(first_cut)
            .chain(deltas.iter().scan(first_cut, |acc, &d| {
                *acc += d;
                Some(*acc)
            }))
            .collect::<Vec<f64>>();

        // Clone parameters for use in the observation closure
        let coefficients_for_obs = coefficients.clone();
        let cutpoints_for_obs = cutpoints.clone();

        // Likelihood using cumulative logits: P(y <= k) = logistic(c_k - eta)
        let _observations <- plate!(obs_idx in features.iter().zip(outcomes.iter()).enumerate() => {
            let (idx, (x_vec, &y)) = obs_idx;
            let linear_pred: f64 = coefficients_for_obs
                .iter()
                .zip(x_vec.iter())
                .map(|(coef, &x_val)| coef * x_val)
                .sum();
            let logistic = |z: f64| 1.0 / (1.0 + (-z).exp());

            // Category probabilities as differences of consecutive cumulative probabilities
            let mut probs = Vec::with_capacity(n_categories);
            let mut prev_cum = 0.0;
            for k in 0..n_categories {
                let cum = if k < n_cuts { logistic(cutpoints_for_obs[k] - linear_pred) } else { 1.0 };
                probs.push((cum - prev_cum).max(1e-10));
                prev_cum = cum;
            }
            // Renormalize so the 1e-10 floor does not break the sum-to-one constraint
            let total: f64 = probs.iter().sum();
            let probs: Vec<f64> = probs.iter().map(|p| p / total).collect();

            observe(addr!("y", idx), Categorical::new(probs).unwrap(), y)
        });

        pure((coefficients, cutpoints))
    }
}
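The category-probability step is worth checking on its own. With ordered cutpoints, the cumulative probabilities are increasing, so their consecutive differences are non-negative and telescope to 1. The following is a minimal std-only sketch of that computation. `ordinal_probs` is a hypothetical name, and `cutpoints` holds the K-1 ordered cutpoints.

```rust
/// Category probabilities under the proportional-odds (cumulative logit) model:
/// P(y <= k) = logistic(c_k - eta) for k = 0..K-2, and P(y <= K-1) = 1.
fn ordinal_probs(cutpoints: &[f64], eta: f64) -> Vec<f64> {
    let logistic = |z: f64| 1.0 / (1.0 + (-z).exp());
    // Cumulative probabilities with an implicit 0 below the first category and 1 above the last
    let mut cum: Vec<f64> = cutpoints.iter().map(|&c| logistic(c - eta)).collect();
    cum.insert(0, 0.0);
    cum.push(1.0);
    // P(y = k) = P(y <= k) - P(y <= k - 1)
    cum.windows(2).map(|w| w[1] - w[0]).collect()
}
```

Increasing the linear predictor shifts probability mass toward higher categories, which is the proportional-odds behavior the model relies on.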

Robust Classification

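Handle outliers (such as mislabeled points) with a heavier-tailed link. The model below adds Student-t noise to the log-odds. Note that it introduces two latent variables per observation, which can slow MCMC mixing: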

use fugue::*;

// Robust logistic regression: latent Student-t noise on the log-odds, built as a scale
// mixture (lambda_i ~ Gamma(nu/2, nu/2), z_i ~ Normal(eta_i, 1/sqrt(lambda_i))).
// Assumes Fugue's Gamma::new takes (shape, rate).
fn robust_classification_model(
    features: Vec<Vec<f64>>,
    labels: Vec<bool>,
) -> Model<(Vec<f64>, f64)> {
    let n_features = features[0].len();

    prob! {
        // Coefficients
        let coefficients <- plate!(i in 0..n_features => {
            sample(addr!("beta", i), fugue::Normal::new(0.0, 2.0).unwrap())
        });

        // Degrees of freedom: smaller nu means heavier tails and more outlier resistance
        let nu <- sample(addr!("nu"), Gamma::new(2.0, 0.1).unwrap());

        // Clone coefficients for use in the observation closure
        let coefficients_for_obs = coefficients.clone();
        let _observations <- plate!(obs_idx in features.iter().zip(labels.iter()).enumerate() => {
            let (idx, (x_vec, &y)) = obs_idx;

            // Linear predictor
            let eta: f64 = coefficients_for_obs
                .iter()
                .zip(x_vec.iter())
                .map(|(coef, &x_val)| coef * x_val)
                .sum();

            prob! {
                // Per-observation precision; integrating it out makes z Student-t(nu) around eta
                let lambda <- sample(addr!("lambda", idx), Gamma::new(nu / 2.0, nu / 2.0).unwrap());
                let z <- sample(addr!("z", idx), fugue::Normal::new(eta, 1.0 / lambda.sqrt()).unwrap());

                // Map the heavy-tailed latent log-odds to a probability
                let p = (1.0 / (1.0 + (-z).exp())).clamp(1e-10, 1.0 - 1e-10);
                observe(addr!("y", idx), Bernoulli::new(p).unwrap(), y)
            }
        });

        pure((coefficients, nu))
    }
}

Production Considerations

Scalability

For large datasets, consider:

  1. Mini-batch MCMC: Process data in chunks for memory efficiency
  2. Variational Inference: Approximate posteriors for faster computation
  3. Sparse Models: Use regularization for high-dimensional feature spaces
  4. GPU Acceleration: Vectorized operations for matrix computations

Production Deployment

  • Monitor convergence: Set up automated R-hat checking
  • Prediction pipelines: Cache MCMC samples for fast inference
  • Model updating: Implement online learning for streaming data
  • A/B testing: Use Bayesian methods for experiment analysis
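For the caching point, one simple pattern is to run MCMC offline, store the coefficient draws, and serve posterior predictive probabilities by averaging over them. The following is a minimal std-only sketch. `PosteriorPredictor` is a hypothetical type, not part of Fugue, and each stored draw is assumed to put the intercept first, with inputs that include a leading 1.0 as in the binary demo.

```rust
/// Hypothetical wrapper that holds posterior coefficient draws from an offline MCMC run.
struct PosteriorPredictor {
    draws: Vec<Vec<f64>>,
}

impl PosteriorPredictor {
    fn new(draws: Vec<Vec<f64>>) -> Self {
        assert!(!draws.is_empty(), "need at least one posterior draw");
        Self { draws }
    }

    /// Posterior predictive P(y = 1 | x): the logistic averaged over cached draws.
    fn predict_proba(&self, x: &[f64]) -> f64 {
        let total: f64 = self
            .draws
            .iter()
            .map(|b| {
                let eta: f64 = b.iter().zip(x).map(|(bj, xj)| bj * xj).sum();
                1.0 / (1.0 + (-eta).exp())
            })
            .sum();
        total / self.draws.len() as f64
    }
}
```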

Model Diagnostics

Essential checks for classification models:

// Posterior-predictive classification diagnostics (uses only the standard library).
// `samples` are posterior coefficient draws; each feature vector includes the intercept column.
fn classification_diagnostics(samples: &[Vec<f64>], features: &[Vec<f64>], labels: &[bool]) {
    assert!(!samples.is_empty(), "need at least one posterior sample");

    // Predict class 1 when the posterior mean of P(y = 1 | x) exceeds 0.5
    let predictions: Vec<bool> = features
        .iter()
        .map(|x_vec| {
            let prob = samples
                .iter()
                .map(|coeffs| {
                    let linear_pred: f64 = coeffs.iter().zip(x_vec).map(|(coef, x)| coef * x).sum();
                    1.0 / (1.0 + (-linear_pred).exp())
                })
                .sum::<f64>()
                / samples.len() as f64;
            prob > 0.5
        })
        .collect();

    // Confusion-matrix counts
    let (mut tp, mut tn, mut fp, mut fn_) = (0usize, 0usize, 0usize, 0usize);
    for (&pred, &actual) in predictions.iter().zip(labels) {
        match (pred, actual) {
            (true, true) => tp += 1,
            (false, false) => tn += 1,
            (true, false) => fp += 1,
            (false, true) => fn_ += 1,
        }
    }

    // Report 0.0 instead of NaN when a ratio is undefined (e.g. no positive predictions)
    let ratio = |num: usize, den: usize| if den == 0 { 0.0 } else { num as f64 / den as f64 };
    let accuracy = ratio(tp + tn, labels.len());
    let precision = ratio(tp, tp + fp);
    let recall = ratio(tp, tp + fn_);
    let f1 = if precision + recall > 0.0 {
        2.0 * precision * recall / (precision + recall)
    } else {
        0.0
    };

    println!("Classification Diagnostics:");
    println!("  Accuracy: {:.3}", accuracy);
    println!("  Precision: {:.3}", precision);
    println!("  Recall: {:.3}", recall);
    println!("  F1-Score: {:.3}", f1);
}

Running the Examples

To explore these classification techniques:

# Run the classification demonstrations
cargo run --example classification

# Run specific tests
cargo test --example classification

# Build documentation with examples
mdbook build docs/

Key Takeaways

Classification Mastery

  1. Bayesian Advantage: Natural uncertainty quantification through posterior distributions
  2. Model Flexibility: Handle binary, multi-class, ordinal, and hierarchical outcomes
  3. Robust Methods: Constraint-aware MCMC prevents numerical issues
  4. Principled Selection: Use information criteria and Bayes factors for model choice
  5. Production Ready: Scalable workflows with proper diagnostics and validation
  6. Real-World Applications: Flexible framework for diverse classification problems

Core Techniques:

  • Binary Classification with logistic regression and uncertainty
  • Multi-class Methods using multinomial and one-vs-rest approaches
  • Hierarchical Models for grouped and nested data structures
  • Model Comparison with information criteria and Bayes factors
  • Robust Extensions for outlier resistance and stability
  • Production Deployment with monitoring and scalable inference

Further Reading

  • Building Complex Models - Advanced modeling techniques
  • Optimizing Performance - Scalable inference strategies
  • Hierarchical Models - Advanced multilevel modeling
  • Mixture Models - Unsupervised classification and clustering
  • Time Series - Classification with temporal structure
  • Gelman et al. "Bayesian Data Analysis" - Comprehensive statistical reference
  • McElreath "Statistical Rethinking" - Modern Bayesian approach
  • Kruschke "Doing Bayesian Data Analysis" - Applied Bayesian methods

The combination of Fugue's type-safe probabilistic programming and constraint-aware MCMC makes Bayesian classification both theoretically principled and computationally practical. The natural uncertainty quantification provides insights that traditional point estimates cannot match.