pymc-bayesian-modeling
This skill provides a detailed guide for Bayesian modeling using PyMC, covering the full workflow from data preparation to model validation. It includes practical code examples for building hierarchical models, MCMC sampling, variational inference, and model comparison techniques like LOO/WAIC. The documentation addresses common pitfalls like divergences and offers reparameterization solutions.
Packaged view
This page reorganizes the original catalog entry around fit, installability, and workflow context first. The original raw source lives below.
Install command
npx @skill-hub/cli install microck-ordinary-claude-skills-pymc
Repository
Skill path: skills_all/claude-scientific-skills/scientific-skills/pymc
Best for
Primary workflow: Analyze Data & AI.
Technical facets: Data / AI, Full Stack.
Target audience: Data scientists and researchers with intermediate Python skills who need to implement Bayesian models for statistical inference, uncertainty quantification, or hierarchical modeling tasks.
License: Unknown.
Original source
Catalog source: SkillHub Club.
Repository owner: Microck.
This is a mirrored public skill entry. Review the repository before installing it into production workflows.
What it helps with
- Install pymc-bayesian-modeling into Claude Code, Codex CLI, Gemini CLI, or OpenCode workflows
- Review https://github.com/Microck/ordinary-claude-skills before adding pymc-bayesian-modeling to shared team environments
- Use pymc-bayesian-modeling for ai/ml workflows
Works across
Favorites: 0.
Sub-skills: 0.
Aggregator: No.
Original source / Raw SKILL.md
---
name: pymc-bayesian-modeling
description: "Bayesian modeling with PyMC. Build hierarchical models, MCMC (NUTS), variational inference, LOO/WAIC comparison, and posterior checks for probabilistic programming and inference."
---
# PyMC Bayesian Modeling
## Overview
PyMC is a Python library for Bayesian modeling and probabilistic programming. Build, fit, validate, and compare Bayesian models using PyMC's modern API (version 5.x+), including hierarchical models, MCMC sampling (NUTS), variational inference, and model comparison (LOO, WAIC).
## When to Use This Skill
This skill should be used when:
- Building Bayesian models (linear/logistic regression, hierarchical models, time series, etc.)
- Performing MCMC sampling or variational inference
- Conducting prior/posterior predictive checks
- Diagnosing sampling issues (divergences, convergence, ESS)
- Comparing multiple models using information criteria (LOO, WAIC)
- Implementing uncertainty quantification through Bayesian methods
- Working with hierarchical/multilevel data structures
- Handling missing data or measurement error in a principled way
## Standard Bayesian Workflow
Follow this workflow for building and validating Bayesian models:
### 1. Data Preparation
```python
import pymc as pm
import arviz as az
import numpy as np
# Load and prepare data
X = ... # Predictors
y = ... # Outcomes
# Standardize predictors for better sampling
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std
```
**Key practices:**
- Standardize continuous predictors (improves sampling efficiency)
- Center outcomes when possible
- Handle missing data explicitly (treat as parameters)
- Use named dimensions with `coords` for clarity
### 2. Model Building
```python
coords = {
    'predictors': ['var1', 'var2', 'var3'],
    'obs_id': np.arange(len(y))
}

with pm.Model(coords=coords) as model:
    # Register predictors as a data container so they can be swapped at prediction time
    X_data = pm.Data('X_scaled', X_scaled, dims=('obs_id', 'predictors'))

    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')
```
**Key practices:**
- Use weakly informative priors (not flat priors)
- Use `HalfNormal` or `Exponential` for scale parameters
- Use named dimensions (`dims`) instead of `shape` when possible
- Use `pm.Data()` for values that will be updated for predictions
### 3. Prior Predictive Check
**Always validate priors before fitting:**
```python
with model:
    prior_pred = pm.sample_prior_predictive(draws=1000, random_seed=42)

# Visualize
az.plot_ppc(prior_pred, group='prior')
```
**Check:**
- Do prior predictions span reasonable values?
- Are extreme values plausible given domain knowledge?
- If priors generate implausible data, adjust and re-check
### 4. Fit Model
```python
with model:
    # Optional: quick exploration with ADVI
    # approx = pm.fit(n=20000)

    # Full MCMC inference
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={'log_likelihood': True}  # For model comparison
    )
```
**Key parameters:**
- `draws=2000`: Number of samples per chain
- `tune=1000`: Warmup samples (discarded)
- `chains=4`: Run 4 chains for convergence checking
- `target_accept=0.9`: Higher for difficult posteriors (0.95-0.99)
- Include `log_likelihood=True` for model comparison
### 5. Check Diagnostics
**Use the diagnostic script:**
```python
from scripts.model_diagnostics import check_diagnostics
results = check_diagnostics(idata, var_names=['alpha', 'beta', 'sigma'])
```
**Check:**
- **R-hat < 1.01**: Chains have converged
- **ESS > 400**: Sufficient effective samples
- **No divergences**: NUTS sampled successfully
- **Trace plots**: Chains should mix well (fuzzy caterpillar)
**If issues arise:**
- Divergences → Increase `target_accept=0.95`, use non-centered parameterization
- Low ESS → Sample more draws, reparameterize to reduce correlation
- High R-hat → Run longer, check for multimodality
### 6. Posterior Predictive Check
**Validate model fit:**
```python
with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)

# Visualize
az.plot_ppc(idata)
```
**Check:**
- Do posterior predictions capture observed data patterns?
- Are systematic deviations evident (model misspecification)?
- Consider alternative models if fit is poor
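Beyond the visual check, a numeric posterior predictive p-value for a chosen test statistic can flag systematic deviations. A minimal NumPy sketch (the arrays below are simulated stand-ins; in practice take `y_rep` from `idata.posterior_predictive['y_obs']` stacked over chain and draw):

```python
import numpy as np

# Stand-in arrays: y is the observed data, y_rep holds posterior predictive draws
rng = np.random.default_rng(42)
y = rng.normal(0.0, 1.0, size=50)
y_rep = rng.normal(0.0, 1.0, size=(4000, 50))  # (draw, observation)

# Posterior predictive p-value for a test statistic (here: the std of the data).
# Values near 0 or 1 indicate a statistic the model fails to reproduce.
stat_obs = y.std()
stat_rep = y_rep.std(axis=1)
p_value = (stat_rep >= stat_obs).mean()
print(f"PPC p-value for std: {p_value:.3f}")
```

The same pattern works for any statistic (max, proportion of zeros, autocorrelation) that matters for the application.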
### 7. Analyze Results
```python
# Summary statistics
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))
# Posterior distributions
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])
# Coefficient estimates
az.plot_forest(idata, var_names=['beta'], combined=True)
```
### 8. Make Predictions
```python
X_new = ... # New predictor values
X_new_scaled = (X_new - X_mean) / X_std
with model:
    # Requires X_scaled to be registered with pm.Data() in the model (step 2)
    pm.set_data({'X_scaled': X_new_scaled}, coords={'obs_id': np.arange(len(X_new))})
    post_pred = pm.sample_posterior_predictive(
        idata,
        var_names=['y_obs'],
        random_seed=42
    )
# Extract prediction intervals
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs'])
```
## Common Model Patterns
### Linear Regression
For continuous outcomes with linear relationships:
```python
with pm.Model() as linear_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    sigma = pm.HalfNormal('sigma', sigma=1)
    mu = alpha + pm.math.dot(X, beta)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```
**Use template:** `assets/linear_regression_template.py`
### Logistic Regression
For binary outcomes:
```python
with pm.Model() as logistic_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    logit_p = alpha + pm.math.dot(X, beta)
    y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs)
```
### Hierarchical Models
For grouped data (use non-centered parameterization):
```python
with pm.Model(coords={'groups': group_names}) as hierarchical_model:
    # Hyperpriors
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)

    # Group-level (non-centered)
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups')
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups')

    # Observation-level
    mu = alpha[group_idx]
    sigma = pm.HalfNormal('sigma', sigma=1)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)
```
**Use template:** `assets/hierarchical_model_template.py`
**Critical:** Always use non-centered parameterization for hierarchical models to avoid divergences.
### Poisson Regression
For count data:
```python
with pm.Model() as poisson_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    log_lambda = alpha + pm.math.dot(X, beta)
    y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs)
```
For overdispersed counts, use `NegativeBinomial` instead.
### Time Series
For autoregressive processes:
```python
with pm.Model() as ar_model:
    sigma = pm.HalfNormal('sigma', sigma=1)
    rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order)
    init_dist = pm.Normal.dist(mu=0, sigma=sigma)
    y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs)
```
## Model Comparison
### Comparing Models
Use LOO or WAIC for model comparison:
```python
from scripts.model_comparison import compare_models, check_loo_reliability
# Fit models with log_likelihood
models = {
    'Model1': idata1,
    'Model2': idata2,
    'Model3': idata3
}
# Compare using LOO
comparison = compare_models(models, ic='loo')
# Check reliability
check_loo_reliability(models)
```
**Interpretation:**
- **Δloo < 2**: Models are similar, choose simpler model
- **2 < Δloo < 4**: Weak evidence for better model
- **4 < Δloo < 10**: Moderate evidence
- **Δloo > 10**: Strong evidence for better model
**Check Pareto-k values:**
- k < 0.7: LOO reliable
- k > 0.7: Consider WAIC or k-fold CV
### Model Averaging
When models are similar, average predictions:
```python
from scripts.model_comparison import model_averaging
averaged_pred, weights = model_averaging(models, var_name='y_obs')
```
## Distribution Selection Guide
### For Priors
**Scale parameters** (σ, τ):
- `pm.HalfNormal('sigma', sigma=1)` - Default choice
- `pm.Exponential('sigma', lam=1)` - Alternative
- `pm.Gamma('sigma', alpha=2, beta=1)` - More informative
**Unbounded parameters**:
- `pm.Normal('theta', mu=0, sigma=1)` - For standardized data
- `pm.StudentT('theta', nu=3, mu=0, sigma=1)` - Robust to outliers
**Positive parameters**:
- `pm.LogNormal('theta', mu=0, sigma=1)`
- `pm.Gamma('theta', alpha=2, beta=1)`
**Probabilities**:
- `pm.Beta('p', alpha=2, beta=2)` - Weakly informative
- `pm.Uniform('p', lower=0, upper=1)` - Non-informative (use sparingly)
**Correlation matrices**:
- `pm.LKJCorr('corr', n=n_vars, eta=2)` - eta=1 uniform, eta>1 prefers identity
### For Likelihoods
**Continuous outcomes**:
- `pm.Normal('y', mu=mu, sigma=sigma)` - Default for continuous data
- `pm.StudentT('y', nu=nu, mu=mu, sigma=sigma)` - Robust to outliers
**Count data**:
- `pm.Poisson('y', mu=lam)` - Equidispersed counts (avoid naming the rate `lambda`, a reserved word in Python)
- `pm.NegativeBinomial('y', mu=mu, alpha=alpha)` - Overdispersed counts
- `pm.ZeroInflatedPoisson('y', psi=psi, mu=mu)` - Excess zeros
**Binary outcomes**:
- `pm.Bernoulli('y', p=p)` or `pm.Bernoulli('y', logit_p=logit_p)`
**Categorical outcomes**:
- `pm.Categorical('y', p=probs)`
**See:** `references/distributions.md` for comprehensive distribution reference
## Sampling and Inference
### MCMC with NUTS
Default and recommended for most models:
```python
idata = pm.sample(
    draws=2000,
    tune=1000,
    chains=4,
    target_accept=0.9,
    random_seed=42
)
```
**Adjust when needed:**
- Divergences → `target_accept=0.95` or higher
- Slow sampling → Use ADVI for initialization
- Discrete parameters → Use `pm.Metropolis()` for discrete vars
### Variational Inference
Fast approximation for exploration or initialization:
```python
with model:
    approx = pm.fit(n=20000, method='advi')
    vi_idata = approx.sample(1000)  # approximate posterior draws

# Or let pm.sample() run ADVI initialization for NUTS directly:
with model:
    idata = pm.sample(init='advi+adapt_diag')
```
**Trade-offs:**
- Much faster than MCMC
- Approximate (may underestimate uncertainty)
- Good for large models or quick exploration
**See:** `references/sampling_inference.md` for detailed sampling guide
## Diagnostic Scripts
### Comprehensive Diagnostics
```python
from scripts.model_diagnostics import create_diagnostic_report
create_diagnostic_report(
    idata,
    var_names=['alpha', 'beta', 'sigma'],
    output_dir='diagnostics/'
)
```
Creates:
- Trace plots
- Rank plots (mixing check)
- Autocorrelation plots
- Energy plots
- ESS evolution
- Summary statistics CSV
### Quick Diagnostic Check
```python
from scripts.model_diagnostics import check_diagnostics
results = check_diagnostics(idata)
```
Checks R-hat, ESS, divergences, and tree depth.
## Common Issues and Solutions
### Divergences
**Symptom:** `idata.sample_stats.diverging.sum() > 0`
**Solutions:**
1. Increase `target_accept=0.95` or `0.99`
2. Use non-centered parameterization (hierarchical models)
3. Add stronger priors to constrain parameters
4. Check for model misspecification
### Low Effective Sample Size
**Symptom:** `ESS < 400`
**Solutions:**
1. Sample more draws: `draws=5000`
2. Reparameterize to reduce posterior correlation
3. Use QR decomposition for regression with correlated predictors
### High R-hat
**Symptom:** `R-hat > 1.01`
**Solutions:**
1. Run longer chains: `tune=2000, draws=5000`
2. Check for multimodality
3. Improve initialization with ADVI
### Slow Sampling
**Solutions:**
1. Use ADVI initialization
2. Reduce model complexity
3. Increase parallelization: `cores=8, chains=8`
4. Use variational inference if appropriate
## Best Practices
### Model Building
1. **Always standardize predictors** for better sampling
2. **Use weakly informative priors** (not flat)
3. **Use named dimensions** (`dims`) for clarity
4. **Non-centered parameterization** for hierarchical models
5. **Check prior predictive** before fitting
### Sampling
1. **Run multiple chains** (at least 4) for convergence
2. **Use `target_accept=0.9`** as baseline (higher if needed)
3. **Include `log_likelihood=True`** for model comparison
4. **Set random seed** for reproducibility
### Validation
1. **Check diagnostics** before interpretation (R-hat, ESS, divergences)
2. **Posterior predictive check** for model validation
3. **Compare multiple models** when appropriate
4. **Report uncertainty** (HDI intervals, not just point estimates)
### Workflow
1. Start simple, add complexity gradually
2. Prior predictive check → Fit → Diagnostics → Posterior predictive check
3. Iterate on model specification based on checks
4. Document assumptions and prior choices
## Resources
This skill includes:
### References (`references/`)
- **`distributions.md`**: Comprehensive catalog of PyMC distributions organized by category (continuous, discrete, multivariate, mixture, time series). Use when selecting priors or likelihoods.
- **`sampling_inference.md`**: Detailed guide to sampling algorithms (NUTS, Metropolis, SMC), variational inference (ADVI, SVGD), and handling sampling issues. Use when encountering convergence problems or choosing inference methods.
- **`workflows.md`**: Complete workflow examples and code patterns for common model types, data preparation, prior selection, and model validation. Use as a cookbook for standard Bayesian analyses.
### Scripts (`scripts/`)
- **`model_diagnostics.py`**: Automated diagnostic checking and report generation. Functions: `check_diagnostics()` for quick checks, `create_diagnostic_report()` for comprehensive analysis with plots.
- **`model_comparison.py`**: Model comparison utilities using LOO/WAIC. Functions: `compare_models()`, `check_loo_reliability()`, `model_averaging()`.
### Templates (`assets/`)
- **`linear_regression_template.py`**: Complete template for Bayesian linear regression with full workflow (data prep, prior checks, fitting, diagnostics, predictions).
- **`hierarchical_model_template.py`**: Complete template for hierarchical/multilevel models with non-centered parameterization and group-level analysis.
## Quick Reference
### Model Building
```python
with pm.Model(coords={'var': names}) as model:
    # Priors
    param = pm.Normal('param', mu=0, sigma=1, dims='var')
    # Likelihood
    y = pm.Normal('y', mu=..., sigma=..., observed=data)
```
### Sampling
```python
idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
```
### Diagnostics
```python
from scripts.model_diagnostics import check_diagnostics
check_diagnostics(idata)
```
### Model Comparison
```python
from scripts.model_comparison import compare_models
compare_models({'m1': idata1, 'm2': idata2}, ic='loo')
```
### Predictions
```python
with model:
    pm.set_data({'X': X_new})
    pred = pm.sample_posterior_predictive(idata)
```
## Additional Notes
- PyMC integrates with ArviZ for visualization and diagnostics
- Use `pm.model_to_graphviz(model)` to visualize model structure
- Save results with `idata.to_netcdf('results.nc')`
- Load with `az.from_netcdf('results.nc')`
- For very large models, consider minibatch ADVI or data subsampling
---
## Referenced Files
> The following files are referenced in this skill and included for context.
### assets/linear_regression_template.py
```python
"""
PyMC Linear Regression Template
This template provides a complete workflow for Bayesian linear regression,
including data preparation, model building, diagnostics, and predictions.
Customize the sections marked with # TODO
"""
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# =============================================================================
# 1. DATA PREPARATION
# =============================================================================
# TODO: Load your data
# Example:
# df = pd.read_csv('data.csv')
# X = df[['predictor1', 'predictor2', 'predictor3']].values
# y = df['outcome'].values
# For demonstration:
np.random.seed(42)
n_samples = 100
n_predictors = 3
X = np.random.randn(n_samples, n_predictors)
true_beta = np.array([1.5, -0.8, 2.1])
true_alpha = 0.5
y = true_alpha + X @ true_beta + np.random.randn(n_samples) * 0.5
# Standardize predictors for better sampling
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std
# =============================================================================
# 2. BUILD MODEL
# =============================================================================
# TODO: Customize predictor names
predictor_names = ['predictor1', 'predictor2', 'predictor3']
coords = {
    'predictors': predictor_names,
    'obs_id': np.arange(len(y))
}

with pm.Model(coords=coords) as linear_model:
    # Register predictors as a data container so step 8 can swap in new data
    X_data = pm.Data('X_scaled', X_scaled, dims=('obs_id', 'predictors'))

    # Priors
    # TODO: Adjust prior parameters based on your domain knowledge
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Linear predictor
    mu = alpha + pm.math.dot(X_data, beta)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')
# =============================================================================
# 3. PRIOR PREDICTIVE CHECK
# =============================================================================
print("Running prior predictive check...")
with linear_model:
    prior_pred = pm.sample_prior_predictive(draws=1000, random_seed=42)
# Visualize prior predictions
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100, ax=ax)
ax.set_title('Prior Predictive Check')
plt.tight_layout()
plt.savefig('prior_predictive_check.png', dpi=300, bbox_inches='tight')
print("Prior predictive check saved to 'prior_predictive_check.png'")
# =============================================================================
# 4. FIT MODEL
# =============================================================================
print("\nFitting model...")
with linear_model:
    # Optional: Quick ADVI exploration
    # approx = pm.fit(n=20000, random_seed=42)

    # MCMC sampling
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={'log_likelihood': True}
    )
print("Sampling complete!")
# =============================================================================
# 5. CHECK DIAGNOSTICS
# =============================================================================
print("\n" + "="*60)
print("DIAGNOSTICS")
print("="*60)
# Summary statistics
summary = az.summary(idata, var_names=['alpha', 'beta', 'sigma'])
print("\nParameter Summary:")
print(summary)
# Check convergence
bad_rhat = summary[summary['r_hat'] > 1.01]
if len(bad_rhat) > 0:
    print(f"\n⚠️ WARNING: {len(bad_rhat)} parameters with R-hat > 1.01")
    print(bad_rhat[['r_hat']])
else:
    print("\n✓ All R-hat values < 1.01 (good convergence)")
# Check effective sample size
low_ess = summary[summary['ess_bulk'] < 400]
if len(low_ess) > 0:
    print(f"\n⚠️ WARNING: {len(low_ess)} parameters with ESS < 400")
    print(low_ess[['ess_bulk', 'ess_tail']])
else:
    print("\n✓ All ESS values > 400 (sufficient samples)")
# Check divergences
divergences = idata.sample_stats.diverging.sum().item()
if divergences > 0:
    print(f"\n⚠️ WARNING: {divergences} divergent transitions")
    print("   Consider increasing target_accept or reparameterizing")
else:
    print("\n✓ No divergences")
# Trace plots
fig, axes = plt.subplots(3, 2, figsize=(12, 8))
az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'], axes=axes)
plt.tight_layout()
plt.savefig('trace_plots.png', dpi=300, bbox_inches='tight')
print("\nTrace plots saved to 'trace_plots.png'")
# =============================================================================
# 6. POSTERIOR PREDICTIVE CHECK
# =============================================================================
print("\nRunning posterior predictive check...")
with linear_model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)
# Visualize fit
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(idata, num_pp_samples=100, ax=ax)
ax.set_title('Posterior Predictive Check')
plt.tight_layout()
plt.savefig('posterior_predictive_check.png', dpi=300, bbox_inches='tight')
print("Posterior predictive check saved to 'posterior_predictive_check.png'")
# =============================================================================
# 7. ANALYZE RESULTS
# =============================================================================
# Posterior distributions
# Let ArviZ size the grid (alpha, three beta panels, and sigma need five axes)
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])
plt.tight_layout()
plt.savefig('posterior_distributions.png', dpi=300, bbox_inches='tight')
print("Posterior distributions saved to 'posterior_distributions.png'")
# Forest plot for coefficients
fig, ax = plt.subplots(figsize=(8, 6))
az.plot_forest(idata, var_names=['beta'], combined=True, ax=ax)
ax.set_title('Coefficient Estimates (95% HDI)')
ax.set_yticklabels(predictor_names)
plt.tight_layout()
plt.savefig('coefficient_forest_plot.png', dpi=300, bbox_inches='tight')
print("Forest plot saved to 'coefficient_forest_plot.png'")
# Print coefficient estimates
print("\n" + "="*60)
print("COEFFICIENT ESTIMATES")
print("="*60)
beta_samples = idata.posterior['beta']
for i, name in enumerate(predictor_names):
    mean = beta_samples.sel(predictors=name).mean().item()
    hdi = az.hdi(beta_samples.sel(predictors=name), hdi_prob=0.95)
    print(f"{name:20s}: {mean:7.3f} [95% HDI: {hdi.values[0]:7.3f}, {hdi.values[1]:7.3f}]")
# =============================================================================
# 8. PREDICTIONS FOR NEW DATA
# =============================================================================
# TODO: Provide new data for predictions
# X_new = np.array([[...], [...], ...]) # New predictor values
# For demonstration, use some test data
X_new = np.random.randn(10, n_predictors)
X_new_scaled = (X_new - X_mean) / X_std
# Update model data and predict
with linear_model:
    # Swap in the new predictors; obs_id is a coordinate, so it is updated
    # via coords rather than set_data
    pm.set_data({'X_scaled': X_new_scaled}, coords={'obs_id': np.arange(len(X_new))})
    post_pred = pm.sample_posterior_predictive(
        idata,
        var_names=['y_obs'],
        random_seed=42
    )
# Extract predictions
y_pred_samples = post_pred.posterior_predictive['y_obs']
y_pred_mean = y_pred_samples.mean(dim=['chain', 'draw']).values
y_pred_hdi = az.hdi(y_pred_samples, hdi_prob=0.95).values
print("\n" + "="*60)
print("PREDICTIONS FOR NEW DATA")
print("="*60)
print(f"{'Index':<10} {'Mean':<15} {'95% HDI Lower':<15} {'95% HDI Upper':<15}")
print("-"*60)
for i in range(len(X_new)):
    print(f"{i:<10} {y_pred_mean[i]:<15.3f} {y_pred_hdi[i, 0]:<15.3f} {y_pred_hdi[i, 1]:<15.3f}")
# =============================================================================
# 9. SAVE RESULTS
# =============================================================================
# Save InferenceData
idata.to_netcdf('linear_regression_results.nc')
print("\nResults saved to 'linear_regression_results.nc'")
# Save summary to CSV
summary.to_csv('model_summary.csv')
print("Summary saved to 'model_summary.csv'")
print("\n" + "="*60)
print("ANALYSIS COMPLETE")
print("="*60)
```
### assets/hierarchical_model_template.py
```python
"""
PyMC Hierarchical/Multilevel Model Template
This template provides a complete workflow for Bayesian hierarchical models,
useful for grouped/nested data (e.g., students within schools, patients within hospitals).
Customize the sections marked with # TODO
"""
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# =============================================================================
# 1. DATA PREPARATION
# =============================================================================
# TODO: Load your data with group structure
# Example:
# df = pd.read_csv('data.csv')
# groups = df['group_id'].values
# X = df['predictor'].values
# y = df['outcome'].values
# For demonstration: Generate hierarchical data
np.random.seed(42)
n_groups = 10
n_per_group = 20
n_obs = n_groups * n_per_group
# True hierarchical structure
true_mu_alpha = 5.0
true_sigma_alpha = 2.0
true_mu_beta = 1.5
true_sigma_beta = 0.5
true_sigma = 1.0
group_alphas = np.random.normal(true_mu_alpha, true_sigma_alpha, n_groups)
group_betas = np.random.normal(true_mu_beta, true_sigma_beta, n_groups)
# Generate data
groups = np.repeat(np.arange(n_groups), n_per_group)
X = np.random.randn(n_obs)
y = group_alphas[groups] + group_betas[groups] * X + np.random.randn(n_obs) * true_sigma
# TODO: Customize group names
group_names = [f'Group_{i}' for i in range(n_groups)]
# =============================================================================
# 2. BUILD HIERARCHICAL MODEL
# =============================================================================
print("Building hierarchical model...")
coords = {
    'groups': group_names,
    'obs': np.arange(n_obs)
}

with pm.Model(coords=coords) as hierarchical_model:
    # Data containers (for later predictions)
    X_data = pm.Data('X_data', X)
    groups_data = pm.Data('groups_data', groups)

    # Hyperpriors (population-level parameters)
    # TODO: Adjust hyperpriors based on your domain knowledge
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=5)
    mu_beta = pm.Normal('mu_beta', mu=0, sigma=10)
    sigma_beta = pm.HalfNormal('sigma_beta', sigma=5)

    # Group-level parameters (non-centered parameterization,
    # which improves sampling efficiency)
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups')
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups')
    beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='groups')
    beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='groups')

    # Observation-level model
    mu = alpha[groups_data] + beta[groups_data] * X_data

    # Observation noise
    sigma = pm.HalfNormal('sigma', sigma=5)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs')
print("Model built successfully!")
print(f"Groups: {n_groups}")
print(f"Observations: {n_obs}")
# =============================================================================
# 3. PRIOR PREDICTIVE CHECK
# =============================================================================
print("\nRunning prior predictive check...")
with hierarchical_model:
    prior_pred = pm.sample_prior_predictive(draws=500, random_seed=42)
# Visualize prior predictions
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(prior_pred, group='prior', num_pp_samples=100, ax=ax)
ax.set_title('Prior Predictive Check')
plt.tight_layout()
plt.savefig('hierarchical_prior_check.png', dpi=300, bbox_inches='tight')
print("Prior predictive check saved to 'hierarchical_prior_check.png'")
# =============================================================================
# 4. FIT MODEL
# =============================================================================
print("\nFitting hierarchical model...")
print("(This may take a few minutes due to model complexity)")
with hierarchical_model:
    # MCMC sampling with higher target_accept for hierarchical models
    idata = pm.sample(
        draws=2000,
        tune=2000,  # More tuning for hierarchical models
        chains=4,
        target_accept=0.95,  # Higher for better convergence
        random_seed=42,
        idata_kwargs={'log_likelihood': True}
    )
print("Sampling complete!")
# =============================================================================
# 5. CHECK DIAGNOSTICS
# =============================================================================
print("\n" + "="*60)
print("DIAGNOSTICS")
print("="*60)
# Summary for key parameters
summary = az.summary(
    idata,
    var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma', 'alpha', 'beta']
)
print("\nParameter Summary:")
print(summary)
# Check convergence
bad_rhat = summary[summary['r_hat'] > 1.01]
if len(bad_rhat) > 0:
    print(f"\n⚠️ WARNING: {len(bad_rhat)} parameters with R-hat > 1.01")
    print(bad_rhat[['r_hat']])
else:
    print("\n✓ All R-hat values < 1.01 (good convergence)")
# Check effective sample size
low_ess = summary[summary['ess_bulk'] < 400]
if len(low_ess) > 0:
    print(f"\n⚠️ WARNING: {len(low_ess)} parameters with ESS < 400")
    print(low_ess[['ess_bulk']].head(10))
else:
    print("\n✓ All ESS values > 400 (sufficient samples)")
# Check divergences
divergences = idata.sample_stats.diverging.sum().item()
if divergences > 0:
    print(f"\n⚠️ WARNING: {divergences} divergent transitions")
    print("   This is common in hierarchical models - non-centered parameterization already applied")
    print("   Consider even higher target_accept or stronger hyperpriors")
else:
    print("\n✓ No divergences")
# Trace plots for hyperparameters
fig, axes = plt.subplots(5, 2, figsize=(12, 12))
az.plot_trace(
    idata,
    var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma'],
    axes=axes
)
plt.tight_layout()
plt.savefig('hierarchical_trace_plots.png', dpi=300, bbox_inches='tight')
print("\nTrace plots saved to 'hierarchical_trace_plots.png'")
# =============================================================================
# 6. POSTERIOR PREDICTIVE CHECK
# =============================================================================
print("\nRunning posterior predictive check...")
with hierarchical_model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)
# Visualize fit
fig, ax = plt.subplots(figsize=(10, 6))
az.plot_ppc(idata, num_pp_samples=100, ax=ax)
ax.set_title('Posterior Predictive Check')
plt.tight_layout()
plt.savefig('hierarchical_posterior_check.png', dpi=300, bbox_inches='tight')
print("Posterior predictive check saved to 'hierarchical_posterior_check.png'")
# =============================================================================
# 7. ANALYZE HIERARCHICAL STRUCTURE
# =============================================================================
print("\n" + "="*60)
print("POPULATION-LEVEL (HYPERPARAMETER) ESTIMATES")
print("="*60)
# Population-level estimates
hyper_summary = summary.loc[['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma']]
print(hyper_summary[['mean', 'sd', 'hdi_3%', 'hdi_97%']])
# Forest plot for group-level parameters
fig, axes = plt.subplots(1, 2, figsize=(14, 8))
# Group intercepts
az.plot_forest(idata, var_names=['alpha'], combined=True, ax=axes[0])
axes[0].set_title('Group-Level Intercepts (α)')
axes[0].set_yticklabels(group_names)
axes[0].axvline(idata.posterior['mu_alpha'].mean().item(), color='red', linestyle='--', label='Population mean')
axes[0].legend()
# Group slopes
az.plot_forest(idata, var_names=['beta'], combined=True, ax=axes[1])
axes[1].set_title('Group-Level Slopes (β)')
axes[1].set_yticklabels(group_names)
axes[1].axvline(idata.posterior['mu_beta'].mean().item(), color='red', linestyle='--', label='Population mean')
axes[1].legend()
plt.tight_layout()
plt.savefig('group_level_estimates.png', dpi=300, bbox_inches='tight')
print("\nGroup-level estimates saved to 'group_level_estimates.png'")
# Shrinkage visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Intercepts
alpha_samples = idata.posterior['alpha'].values.reshape(-1, n_groups)
alpha_means = alpha_samples.mean(axis=0)
mu_alpha_mean = idata.posterior['mu_alpha'].mean().item()
axes[0].scatter(range(n_groups), alpha_means, alpha=0.6)
axes[0].axhline(mu_alpha_mean, color='red', linestyle='--', label='Population mean')
axes[0].set_xlabel('Group')
axes[0].set_ylabel('Intercept')
axes[0].set_title('Group Intercepts (showing shrinkage to population mean)')
axes[0].legend()
# Slopes
beta_samples = idata.posterior['beta'].values.reshape(-1, n_groups)
beta_means = beta_samples.mean(axis=0)
mu_beta_mean = idata.posterior['mu_beta'].mean().item()
axes[1].scatter(range(n_groups), beta_means, alpha=0.6)
axes[1].axhline(mu_beta_mean, color='red', linestyle='--', label='Population mean')
axes[1].set_xlabel('Group')
axes[1].set_ylabel('Slope')
axes[1].set_title('Group Slopes (showing shrinkage to population mean)')
axes[1].legend()
plt.tight_layout()
plt.savefig('shrinkage_plot.png', dpi=300, bbox_inches='tight')
print("Shrinkage plot saved to 'shrinkage_plot.png'")
# =============================================================================
# 8. PREDICTIONS FOR NEW DATA
# =============================================================================
# TODO: Specify new data
# For existing groups:
# new_X = np.array([...])
# new_groups = np.array([0, 1, 2, ...]) # Existing group indices
# For a new group (predict using population-level parameters):
# Just use mu_alpha and mu_beta
print("\n" + "="*60)
print("PREDICTIONS FOR NEW DATA")
print("="*60)
# Example: Predict for existing groups
new_X = np.array([-2, -1, 0, 1, 2])
new_groups = np.array([0, 2, 4, 6, 8]) # Select some groups
with hierarchical_model:
    pm.set_data({'X_data': new_X, 'groups_data': new_groups, 'obs': np.arange(len(new_X))})
    post_pred = pm.sample_posterior_predictive(
        idata.posterior,
        var_names=['y_obs'],
        random_seed=42
    )
y_pred_samples = post_pred.posterior_predictive['y_obs']
y_pred_mean = y_pred_samples.mean(dim=['chain', 'draw']).values
y_pred_hdi = az.hdi(y_pred_samples, hdi_prob=0.95)['y_obs'].values
print("Predictions for existing groups:")
print(f"{'Group':<10} {'X':<10} {'Mean':<15} {'95% HDI Lower':<15} {'95% HDI Upper':<15}")
print("-"*65)
for i, g in enumerate(new_groups):
    print(f"{group_names[g]:<10} {new_X[i]:<10.2f} {y_pred_mean[i]:<15.3f} {y_pred_hdi[i, 0]:<15.3f} {y_pred_hdi[i, 1]:<15.3f}")
# Predict for a new group (using population parameters)
print(f"\nPrediction for a NEW group (using population-level parameters):")
new_X_newgroup = np.array([0.0])
# Manually compute using population parameters
mu_alpha_samples = idata.posterior['mu_alpha'].values.flatten()
mu_beta_samples = idata.posterior['mu_beta'].values.flatten()
sigma_samples = idata.posterior['sigma'].values.flatten()
# Predicted mean for new group
y_pred_newgroup = mu_alpha_samples + mu_beta_samples * new_X_newgroup[0]
y_pred_mean_newgroup = y_pred_newgroup.mean()
y_pred_hdi_newgroup = az.hdi(y_pred_newgroup, hdi_prob=0.95)
print(f"X = {new_X_newgroup[0]:.2f}")
print(f"Predicted mean: {y_pred_mean_newgroup:.3f}")
print(f"95% HDI: [{y_pred_hdi_newgroup[0]:.3f}, {y_pred_hdi_newgroup[1]:.3f}]")
# =============================================================================
# 9. SAVE RESULTS
# =============================================================================
idata.to_netcdf('hierarchical_model_results.nc')
print("\nResults saved to 'hierarchical_model_results.nc'")
summary.to_csv('hierarchical_model_summary.csv')
print("Summary saved to 'hierarchical_model_summary.csv'")
print("\n" + "="*60)
print("ANALYSIS COMPLETE")
print("="*60)
```
### references/distributions.md
```markdown
# PyMC Distributions Reference
This reference provides a comprehensive catalog of probability distributions available in PyMC, organized by category. Use this to select appropriate distributions for priors and likelihoods when building Bayesian models.
## Continuous Distributions
Continuous distributions define probability densities over real-valued domains.
### Common Continuous Distributions
**`pm.Normal(name, mu, sigma)`**
- Normal (Gaussian) distribution
- Parameters: `mu` (mean), `sigma` (standard deviation)
- Support: (-∞, ∞)
- Common uses: Default prior for unbounded parameters, likelihood for continuous data with additive noise
**`pm.HalfNormal(name, sigma)`**
- Half-normal distribution (positive half of normal)
- Parameters: `sigma` (standard deviation)
- Support: [0, ∞)
- Common uses: Prior for scale/standard deviation parameters
**`pm.Uniform(name, lower, upper)`**
- Uniform distribution
- Parameters: `lower`, `upper` (bounds)
- Support: [lower, upper]
- Common uses: Weakly informative prior when parameter must be bounded
**`pm.Beta(name, alpha, beta)`**
- Beta distribution
- Parameters: `alpha`, `beta` (shape parameters)
- Support: [0, 1]
- Common uses: Prior for probabilities and proportions
**`pm.Gamma(name, alpha, beta)`**
- Gamma distribution
- Parameters: `alpha` (shape), `beta` (rate)
- Support: (0, ∞)
- Common uses: Prior for positive parameters, rate parameters
**`pm.Exponential(name, lam)`**
- Exponential distribution
- Parameters: `lam` (rate parameter)
- Support: [0, ∞)
- Common uses: Prior for scale parameters, waiting times
**`pm.LogNormal(name, mu, sigma)`**
- Log-normal distribution
- Parameters: `mu`, `sigma` (parameters of underlying normal)
- Support: (0, ∞)
- Common uses: Prior for positive parameters with multiplicative effects
**`pm.StudentT(name, nu, mu, sigma)`**
- Student's t-distribution
- Parameters: `nu` (degrees of freedom), `mu` (location), `sigma` (scale)
- Support: (-∞, ∞)
- Common uses: Robust alternative to normal for outlier-resistant models
**`pm.Cauchy(name, alpha, beta)`**
- Cauchy distribution
- Parameters: `alpha` (location), `beta` (scale)
- Support: (-∞, ∞)
- Common uses: Heavy-tailed alternative to normal
### Specialized Continuous Distributions
**`pm.Laplace(name, mu, b)`** - Laplace (double exponential) distribution
**`pm.AsymmetricLaplace(name, kappa, mu, b)`** - Asymmetric Laplace distribution
**`pm.InverseGamma(name, alpha, beta)`** - Inverse gamma distribution
**`pm.Weibull(name, alpha, beta)`** - Weibull distribution for reliability analysis
**`pm.Logistic(name, mu, s)`** - Logistic distribution
**`pm.LogitNormal(name, mu, sigma)`** - Logit-normal distribution for (0,1) support
**`pm.Pareto(name, alpha, m)`** - Pareto distribution for power-law phenomena
**`pm.ChiSquared(name, nu)`** - Chi-squared distribution
**`pm.ExGaussian(name, mu, sigma, nu)`** - Exponentially modified Gaussian
**`pm.VonMises(name, mu, kappa)`** - Von Mises (circular normal) distribution
**`pm.SkewNormal(name, mu, sigma, alpha)`** - Skew-normal distribution
**`pm.Triangular(name, lower, c, upper)`** - Triangular distribution
**`pm.Gumbel(name, mu, beta)`** - Gumbel distribution for extreme values
**`pm.Rice(name, nu, sigma)`** - Rice (Rician) distribution
**`pm.Moyal(name, mu, sigma)`** - Moyal distribution
**`pm.Kumaraswamy(name, a, b)`** - Kumaraswamy distribution (Beta alternative)
**`pm.Interpolated(name, x_points, pdf_points)`** - Custom distribution from interpolation
## Discrete Distributions
Discrete distributions define probabilities over integer-valued domains.
### Common Discrete Distributions
**`pm.Bernoulli(name, p)`**
- Bernoulli distribution (binary outcome)
- Parameters: `p` (success probability)
- Support: {0, 1}
- Common uses: Binary classification, coin flips
**`pm.Binomial(name, n, p)`**
- Binomial distribution
- Parameters: `n` (number of trials), `p` (success probability)
- Support: {0, 1, ..., n}
- Common uses: Number of successes in fixed trials
**`pm.Poisson(name, mu)`**
- Poisson distribution
- Parameters: `mu` (rate parameter)
- Support: {0, 1, 2, ...}
- Common uses: Count data, rates, occurrences
**`pm.Categorical(name, p)`**
- Categorical distribution
- Parameters: `p` (probability vector)
- Support: {0, 1, ..., K-1}
- Common uses: Multi-class classification
**`pm.DiscreteUniform(name, lower, upper)`**
- Discrete uniform distribution
- Parameters: `lower`, `upper` (bounds)
- Support: {lower, ..., upper}
- Common uses: Uniform prior over finite integers
**`pm.NegativeBinomial(name, mu, alpha)`**
- Negative binomial distribution
- Parameters: `mu` (mean), `alpha` (dispersion)
- Support: {0, 1, 2, ...}
- Common uses: Overdispersed count data
**`pm.Geometric(name, p)`**
- Geometric distribution
- Parameters: `p` (success probability)
- Support: {0, 1, 2, ...}
- Common uses: Number of failures before first success
### Specialized Discrete Distributions
**`pm.BetaBinomial(name, alpha, beta, n)`** - Beta-binomial (overdispersed binomial)
**`pm.HyperGeometric(name, N, k, n)`** - Hypergeometric distribution
**`pm.DiscreteWeibull(name, q, beta)`** - Discrete Weibull distribution
**`pm.OrderedLogistic(name, eta, cutpoints)`** - Ordered logistic for ordinal data
**`pm.OrderedProbit(name, eta, cutpoints)`** - Ordered probit for ordinal data
## Multivariate Distributions
Multivariate distributions define joint probability distributions over vector-valued random variables.
### Common Multivariate Distributions
**`pm.MvNormal(name, mu, cov)`**
- Multivariate normal distribution
- Parameters: `mu` (mean vector), `cov` (covariance matrix)
- Common uses: Correlated continuous variables, Gaussian processes
**`pm.Dirichlet(name, a)`**
- Dirichlet distribution
- Parameters: `a` (concentration parameters)
- Support: Simplex (sums to 1)
- Common uses: Prior for probability vectors, topic modeling
**`pm.Multinomial(name, n, p)`**
- Multinomial distribution
- Parameters: `n` (number of trials), `p` (probability vector)
- Common uses: Count data across multiple categories
**`pm.MvStudentT(name, nu, mu, cov)`**
- Multivariate Student's t-distribution
- Parameters: `nu` (degrees of freedom), `mu` (location), `cov` (scale matrix)
- Common uses: Robust multivariate modeling
### Specialized Multivariate Distributions
**`pm.LKJCorr(name, n, eta)`** - LKJ correlation matrix prior (for correlation matrices)
**`pm.LKJCholeskyCov(name, n, eta, sd_dist)`** - LKJ prior with Cholesky decomposition
**`pm.Wishart(name, nu, V)`** - Wishart distribution (for covariance matrices)
**`pm.InverseWishart(name, nu, V)`** - Inverse Wishart distribution
**`pm.MatrixNormal(name, mu, rowcov, colcov)`** - Matrix normal distribution
**`pm.KroneckerNormal(name, mu, covs, sigma)`** - Kronecker-structured normal
**`pm.CAR(name, mu, W, alpha, tau)`** - Conditional autoregressive (spatial)
**`pm.ICAR(name, W, sigma)`** - Intrinsic conditional autoregressive (spatial)
## Mixture Distributions
Mixture distributions combine multiple component distributions.
**`pm.Mixture(name, w, comp_dists)`**
- General mixture distribution
- Parameters: `w` (weights), `comp_dists` (component distributions)
- Common uses: Clustering, multi-modal data
**`pm.NormalMixture(name, w, mu, sigma)`**
- Mixture of normal distributions
- Common uses: Mixture of Gaussians clustering
### Zero-Inflated and Hurdle Models
**`pm.ZeroInflatedPoisson(name, psi, mu)`** - Excess zeros in count data
**`pm.ZeroInflatedBinomial(name, psi, n, p)`** - Zero-inflated binomial
**`pm.ZeroInflatedNegativeBinomial(name, psi, mu, alpha)`** - Zero-inflated negative binomial
**`pm.HurdlePoisson(name, psi, mu)`** - Hurdle Poisson (two-part model)
**`pm.HurdleGamma(name, psi, alpha, beta)`** - Hurdle gamma
**`pm.HurdleLogNormal(name, psi, mu, sigma)`** - Hurdle log-normal
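The zero-inflated construction is easy to sketch generatively with NumPy. In PyMC's convention `psi` is the probability of the Poisson component (the values below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(5)
n, psi, mu = 100_000, 0.7, 3.0  # psi: prob of the Poisson component

# Zero-inflated Poisson: with prob (1 - psi) emit a structural zero,
# otherwise draw from Poisson(mu)
is_poisson = rng.random(n) < psi
counts = np.where(is_poisson, rng.poisson(mu, size=n), 0)

# Mean is psi * mu; zeros exceed the plain-Poisson rate exp(-mu)
assert abs(counts.mean() - psi * mu) < 0.05
assert (counts == 0).mean() > np.exp(-mu)
```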
## Time Series Distributions
Distributions designed for temporal data and sequential modeling.
**`pm.AR(name, rho, sigma, init_dist)`**
- Autoregressive process
- Parameters: `rho` (AR coefficients), `sigma` (innovation std), `init_dist` (initial distribution)
- Common uses: Time series modeling, sequential data
**`pm.GaussianRandomWalk(name, mu, sigma, init_dist)`**
- Gaussian random walk
- Parameters: `mu` (drift), `sigma` (step size), `init_dist` (initial value)
- Common uses: Cumulative processes, random walk priors
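A Gaussian random walk is just the cumulative sum of i.i.d. normal steps, which a short NumPy sketch makes concrete (drift and step size are illustrative values):

```python
import numpy as np

rng = np.random.default_rng(9)
mu_drift, sigma_step, T = 0.1, 1.0, 10_000

# Random walk = cumulative sum of Normal(mu, sigma) increments
steps = rng.normal(mu_drift, sigma_step, size=T)
walk = np.cumsum(steps)

# The endpoint divided by T estimates the drift
assert abs(walk[-1] / T - mu_drift) < 0.1
# Differencing the walk recovers the original steps
assert np.allclose(np.diff(walk), steps[1:])
```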
**`pm.MvGaussianRandomWalk(name, mu, cov, init_dist)`**
- Multivariate Gaussian random walk
**`pm.GARCH11(name, omega, alpha_1, beta_1)`**
- GARCH(1,1) volatility model
- Common uses: Financial time series, volatility modeling
**`pm.EulerMaruyama(name, dt, sde_fn, sde_pars, init_dist)`**
- Stochastic differential equation via Euler-Maruyama discretization
- Common uses: Continuous-time processes
## Special Distributions
**`pm.Deterministic(name, var)`**
- Deterministic transformation (not a random variable)
- Use for computed quantities derived from other variables
**`pm.Potential(name, logp)`**
- Add arbitrary log-probability contribution
- Use for custom likelihood components or constraints
**`pm.Flat(name)`**
- Improper flat prior (constant density)
- Use sparingly; can cause sampling issues
**`pm.HalfFlat(name)`**
- Improper flat prior on positive reals
- Use sparingly; can cause sampling issues
## Distribution Modifiers
**`pm.Truncated(name, dist, lower, upper)`**
- Truncate any distribution to specified bounds
**`pm.Censored(name, dist, lower, upper)`**
- Handle censored observations (observed bounds, not exact values)
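Truncation and censoring are easy to confuse; a NumPy sketch of the two data-generating processes (bound chosen arbitrarily) shows the difference:

```python
import numpy as np

rng = np.random.default_rng(11)
latent = rng.normal(0.0, 1.0, size=100_000)
upper = 1.0

# Censored: values above the bound are recorded AT the bound
censored = np.minimum(latent, upper)
# Truncated: values above the bound are discarded entirely
truncated = latent[latent <= upper]

assert censored.max() == upper          # pile-up at the bound
assert truncated.max() <= upper         # no observations beyond the bound
assert len(truncated) < len(latent)     # truncation loses observations
```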
**`pm.CustomDist(name, ..., logp, random)`**
- Define custom distributions with user-specified log-probability and random sampling functions
**`pm.Simulator(name, fn, params, ...)`**
- Custom distributions via simulation (for likelihood-free inference)
## Usage Tips
### Choosing Priors
1. **Scale parameters** (σ, τ): Use `HalfNormal`, `HalfCauchy`, `Exponential`, or `Gamma`
2. **Probabilities**: Use `Beta` or `Uniform(0, 1)`
3. **Unbounded parameters**: Use `Normal` or `StudentT` (for robustness)
4. **Positive parameters**: Use `LogNormal`, `Gamma`, or `Exponential`
5. **Correlation matrices**: Use `LKJCorr`
6. **Count data**: Use `Poisson` or `NegativeBinomial` (for overdispersion)
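As a sanity check on tip 1, a half-normal is simply the absolute value of a normal, so its support and mean are easy to verify numerically (`sigma` chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(3)
sigma = 2.0

# HalfNormal(sigma) is |Normal(0, sigma)|: support [0, inf)
half_normal = np.abs(rng.normal(0.0, sigma, size=100_000))

assert half_normal.min() >= 0
# E[HalfNormal(sigma)] = sigma * sqrt(2 / pi)
assert abs(half_normal.mean() - sigma * np.sqrt(2 / np.pi)) < 0.02
```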
### Shape Broadcasting
PyMC distributions support NumPy-style broadcasting. Use the `shape` parameter to create vectors or arrays of random variables:
```python
# Vector of 5 independent normals
beta = pm.Normal('beta', mu=0, sigma=1, shape=5)
# 3x4 matrix of independent gammas
tau = pm.Gamma('tau', alpha=2, beta=1, shape=(3, 4))
```
### Using dims for Named Dimensions
Instead of `shape`, use `dims` with model coordinates for more readable models:
```python
with pm.Model(coords={'predictors': ['age', 'income', 'education']}) as model:
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
```
```
### references/sampling_inference.md
```markdown
# PyMC Sampling and Inference Methods
This reference covers the sampling algorithms and inference methods available in PyMC for posterior inference.
## MCMC Sampling Methods
### Primary Sampling Function
**`pm.sample(draws=1000, tune=1000, chains=4, **kwargs)`**
The main interface for MCMC sampling in PyMC.
**Key Parameters:**
- `draws`: Number of samples to draw per chain (default: 1000)
- `tune`: Number of tuning/warmup samples (default: 1000, discarded)
- `chains`: Number of parallel chains (default: 4)
- `cores`: Number of CPU cores to use (default: all available)
- `target_accept`: Target acceptance rate for step size tuning (default: 0.8, increase to 0.9-0.95 for difficult posteriors)
- `random_seed`: Random seed for reproducibility
- `return_inferencedata`: Return ArviZ InferenceData object (default: True)
- `idata_kwargs`: Additional kwargs for InferenceData creation (e.g., `{"log_likelihood": True}` for model comparison)
**Returns:** InferenceData object containing posterior samples, sampling statistics, and diagnostics
**Example:**
```python
with pm.Model() as model:
    # ... define model ...
    idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
```
### Sampling Algorithms
PyMC automatically selects appropriate samplers based on model structure, but you can specify algorithms manually.
#### NUTS (No-U-Turn Sampler)
**Default algorithm** for continuous parameters. Highly efficient Hamiltonian Monte Carlo variant.
- Automatically tunes step size and mass matrix
- Adaptive: explores posterior geometry during tuning
- Best for smooth, continuous posteriors
- Can struggle with high correlation or multimodality
**Manual specification:**
```python
with model:
    idata = pm.sample(step=pm.NUTS(target_accept=0.95))
```
**When to adjust:**
- Increase `target_accept` (0.9-0.99) if seeing divergences
- The default `init='jitter+adapt_diag'` works well for most models
- Use `init='adapt_diag'` (no jitter) when jittered starting points cause initialization failures
#### Metropolis
General-purpose Metropolis-Hastings sampler.
- Works for both continuous and discrete variables
- Less efficient than NUTS for smooth continuous posteriors
- Useful for discrete parameters or non-differentiable models
- Requires manual tuning
**Example:**
```python
with model:
    idata = pm.sample(step=pm.Metropolis())
```
#### Slice Sampler
Slice sampling for univariate distributions.
- No tuning required
- Good for difficult univariate posteriors
- Can be slow for high dimensions
**Example:**
```python
with model:
    idata = pm.sample(step=pm.Slice())
```
#### CompoundStep
Combine different samplers for different parameters.
**Example:**
```python
with model:
    # Use NUTS for continuous params, Metropolis for discrete
    step1 = pm.NUTS([continuous_var1, continuous_var2])
    step2 = pm.Metropolis([discrete_var])
    idata = pm.sample(step=[step1, step2])
```
### Sampling Diagnostics
PyMC automatically computes diagnostics. Check these before trusting results:
#### Effective Sample Size (ESS)
Measures independent information in correlated samples.
- **Rule of thumb**: bulk ESS > 400 in total (roughly 100 per chain with 4 chains)
- Low ESS indicates high autocorrelation
- Access via: `az.ess(idata)`
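To make the intuition concrete, here is a crude ESS estimate from empirical autocorrelations. This is a simplification of ArviZ's estimator (which additionally uses split chains and Geyer's initial monotone sequence), shown only to illustrate why autocorrelated draws carry less information:

```python
import numpy as np

def ess_ar1(draws):
    """Crude ESS: N / (1 + 2 * sum of initial positive autocorrelations)."""
    x = draws - draws.mean()
    n = len(x)
    acov = np.correlate(x, x, mode='full')[n - 1:] / n
    rho = acov / acov[0]
    s = 0.0
    for r in rho[1:]:
        if r < 0:          # truncate at the first negative autocorrelation
            break
        s += r
    return n / (1 + 2 * s)

rng = np.random.default_rng(1)
iid = rng.normal(size=5000)          # independent draws: ESS close to N
ar = np.empty(5000)                  # AR(1) with phi=0.9: heavily autocorrelated
ar[0] = 0.0
for t in range(1, 5000):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()

assert ess_ar1(iid) > 2000
assert ess_ar1(ar) < 1000
```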
#### R-hat (Gelman-Rubin statistic)
Measures convergence across chains.
- **Rule of thumb**: R-hat < 1.01 for all parameters
- R-hat > 1.01 indicates non-convergence
- Access via: `az.rhat(idata)`
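A minimal NumPy sketch of the classic (non-split) R-hat shows why well-mixed chains give values near 1 while a stuck chain inflates the statistic. ArviZ's implementation additionally splits chains and rank-normalizes:

```python
import numpy as np

def rhat(chains):
    """Basic Gelman-Rubin R-hat for an (n_chains, n_draws) array."""
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()      # within-chain variance
    B = n * chain_means.var(ddof=1)            # between-chain variance
    var_plus = (n - 1) / n * W + B / n         # pooled variance estimate
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(0)
mixed = rng.normal(0, 1, size=(4, 1000))          # chains target the same dist
stuck = mixed + np.array([[0], [0], [0], [3]])    # one chain is off-target

assert rhat(mixed) < 1.01
assert rhat(stuck) > 1.1
```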
#### Divergences
Indicate regions where NUTS struggled.
- **Rule of thumb**: 0 divergences (or very few)
- Divergences suggest biased samples
- **Fix**: Increase `target_accept`, reparameterize, or use stronger priors
- Access via: `idata.sample_stats.diverging.sum()`
#### Energy Plot
Visualizes Hamiltonian Monte Carlo energy transitions.
```python
az.plot_energy(idata)
```
Good separation between energy distributions indicates healthy sampling.
### Handling Sampling Issues
#### Divergences
```python
# Increase target acceptance rate
idata = pm.sample(target_accept=0.95)
# Or reparameterize latent variables with the non-centered form
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)

# Non-centered (samples a standardized offset instead):
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
```
#### Slow Sampling
```python
# Use fewer tuning steps if model is simple
idata = pm.sample(tune=500)
# Increase cores for parallelization
idata = pm.sample(cores=8, chains=8)
# Use variational inference (ADVI) to initialize NUTS
with model:
    idata = pm.sample(init='advi+adapt_diag')
```
#### High Autocorrelation
```python
# Increase draws
idata = pm.sample(draws=5000)
# Reparameterize to reduce correlation
# Consider using QR decomposition for regression models
```
## Variational Inference
Faster approximate inference for large models or quick exploration.
### ADVI (Automatic Differentiation Variational Inference)
**`pm.fit(n=10000, method='advi', **kwargs)`**
Approximates posterior with simpler distribution (typically mean-field Gaussian).
**Key Parameters:**
- `n`: Number of iterations (default: 10000)
- `method`: VI algorithm ('advi', 'fullrank_advi', 'svgd')
- `random_seed`: Random seed
**Returns:** Approximation object for sampling and analysis
**Example:**
```python
with model:
    approx = pm.fit(n=50000)

# Draw samples from the approximation
idata = approx.sample(1000)

# Or draw a single point for MCMC initialization
start = approx.sample(return_inferencedata=False)[0]
```
**Trade-offs:**
- **Pros**: Much faster than MCMC, scales to large data
- **Cons**: Approximate, may miss posterior structure, underestimates uncertainty
### Full-Rank ADVI
Captures correlations between parameters.
```python
with model:
    approx = pm.fit(method='fullrank_advi')
```
More accurate than mean-field but slower.
### SVGD (Stein Variational Gradient Descent)
Non-parametric variational inference.
```python
with model:
    approx = pm.fit(method='svgd', n=20000)
```
Better captures multimodality but more computationally expensive.
## Prior and Posterior Predictive Sampling
### Prior Predictive Sampling
Sample from the prior distribution (before seeing data).
**`pm.sample_prior_predictive(samples=500, **kwargs)`**
**Purpose:**
- Validate priors are reasonable
- Check implied predictions before fitting
- Ensure model generates plausible data
**Example:**
```python
with model:
    prior_pred = pm.sample_prior_predictive(samples=1000)

# Visualize prior predictions
az.plot_ppc(prior_pred, group='prior')
```
### Posterior Predictive Sampling
Sample from posterior predictive distribution (after fitting).
**`pm.sample_posterior_predictive(trace, **kwargs)`**
**Purpose:**
- Model validation via posterior predictive checks
- Generate predictions for new data
- Assess goodness-of-fit
**Example:**
```python
with model:
    # After sampling
    idata = pm.sample()
    # Add posterior predictive samples
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)

# Posterior predictive check
az.plot_ppc(idata)
```
### Predictions for New Data
Update data and sample predictive distribution:
```python
with model:
    # Original model fit
    idata = pm.sample()
    # Update with new predictor values
    pm.set_data({'X': X_new})
    # Sample predictions
    post_pred_new = pm.sample_posterior_predictive(
        idata.posterior,
        var_names=['y_pred']
    )
```
## Maximum A Posteriori (MAP) Estimation
Find posterior mode (point estimate).
**`pm.find_MAP(start=None, method='L-BFGS-B', **kwargs)`**
**When to use:**
- Quick point estimates
- Initialization for MCMC
- When full posterior not needed
**Example:**
```python
with model:
    map_estimate = pm.find_MAP()

print(map_estimate)
```
**Limitations:**
- Doesn't quantify uncertainty
- Can find local optima in multimodal posteriors
- Sensitive to prior specification
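The idea behind MAP estimation can be illustrated without PyMC: minimize the negative log posterior directly. This toy conjugate-normal example (hypothetical data) checks a grid-search optimum against the known closed form:

```python
import numpy as np

rng = np.random.default_rng(13)
data = rng.normal(2.0, 1.0, size=500)

# Toy model: mu ~ Normal(0, 10), data ~ Normal(mu, 1)
# Negative log posterior in mu (up to an additive constant),
# expanded so it can be evaluated on a grid in one vectorized pass
n, sx, sx2 = len(data), data.sum(), (data**2).sum()
grid = np.linspace(0.0, 4.0, 400_001)
neg_log_post = grid**2 / (2 * 10**2) + (sx2 - 2 * grid * sx + n * grid**2) / 2

map_mu = grid[np.argmin(neg_log_post)]
# Gaussian conjugacy gives the mode in closed form: sum(x) / (n + 1/100)
assert abs(map_mu - sx / (n + 1 / 100)) < 1e-4
```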
## Inference Recommendations
### Standard Workflow
1. **Start with ADVI** for quick exploration:
```python
approx = pm.fit(n=20000)
```
2. **Run MCMC** for full inference:
```python
idata = pm.sample(draws=2000, tune=1000)
```
3. **Check diagnostics**:
```python
az.summary(idata)  # inspect the r_hat and ess_bulk columns
```
4. **Sample posterior predictive**:
```python
pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```
### Choosing Inference Method
| Scenario | Recommended Method |
|----------|-------------------|
| Small-medium models, need full uncertainty | MCMC with NUTS |
| Large models, initial exploration | ADVI |
| Discrete parameters | Metropolis or marginalize |
| Hierarchical models with divergences | Non-centered parameterization + NUTS |
| Very large data | Minibatch ADVI |
| Quick point estimates | MAP or ADVI |
### Reparameterization Tricks
**Non-centered parameterization** for hierarchical models:
```python
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)
# Non-centered (better sampling):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
```
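The two parameterizations define the same distribution for `theta`; only the sampling geometry differs. A quick NumPy check of the equivalence (illustrative values):

```python
import numpy as np

rng = np.random.default_rng(42)
mu, sigma, n = 2.0, 0.5, 200_000

# Centered draw: theta ~ Normal(mu, sigma)
centered = rng.normal(mu, sigma, size=n)
# Non-centered draw: theta = mu + sigma * z, with z ~ Normal(0, 1)
noncentered = mu + sigma * rng.normal(0.0, 1.0, size=n)

# Identical distributions, different parameterizations
assert abs(centered.mean() - noncentered.mean()) < 0.01
assert abs(centered.std() - noncentered.std()) < 0.01
```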
**QR decomposition** for correlated predictors:
```python
import numpy as np
import pytensor.tensor as pt

# Thin QR decomposition of the design matrix X (n x p)
Q, R = np.linalg.qr(X)

with pm.Model():
    # Coefficients on the orthogonal basis are approximately uncorrelated
    beta_tilde = pm.Normal('beta_tilde', 0, 1, shape=p)
    # Recover the original-scale coefficients: beta = R^{-1} @ beta_tilde
    beta = pm.Deterministic('beta', pt.linalg.solve(R, beta_tilde))
    mu = pm.math.dot(Q, beta_tilde)
    sigma = pm.HalfNormal('sigma', 1)
    y = pm.Normal('y', mu, sigma, observed=y_obs)
```
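The algebra behind the trick: with X = QR, the linear predictor satisfies X @ beta = Q @ beta_tilde where beta_tilde = R @ beta. A NumPy check on a random toy matrix:

```python
import numpy as np

rng = np.random.default_rng(7)
X = rng.normal(size=(50, 3))
beta = np.array([1.0, -2.0, 0.5])

Q, R = np.linalg.qr(X)      # thin QR: X = Q @ R
beta_tilde = R @ beta       # coefficients in the orthogonal basis

# The linear predictor is unchanged: X @ beta == Q @ beta_tilde
assert np.allclose(X @ beta, Q @ beta_tilde)
# And beta is recovered from beta_tilde via a triangular solve
assert np.allclose(np.linalg.solve(R, beta_tilde), beta)
```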
## Advanced Sampling
### Sequential Monte Carlo (SMC)
For complex posteriors or model evidence estimation:
```python
with model:
    idata = pm.sample_smc(draws=2000, chains=4)
```
Good for multimodal posteriors or when NUTS struggles.
### Custom Initialization
Provide starting values:
```python
start = {'mu': 0, 'sigma': 1}
with model:
    idata = pm.sample(initvals=start)
```
Or use MAP estimate:
```python
with model:
    start = pm.find_MAP()
    idata = pm.sample(initvals=start)
```
```