CausalDatasetGenerator
Generate synthetic causal inference datasets with controllable confounding, treatment prevalence, noise, and (optionally) heterogeneous treatment effects.
Data model (high level)
- Confounders X ∈ R^k are drawn from user-specified distributions.
- Binary treatment D is assigned by a logistic model: D ~ Bernoulli( sigmoid(alpha_d + f_d(X) + u_strength_d * U) ), where f_d(X) = (X @ beta_d + g_d(X)) * propensity_sharpness, and U ~ N(0,1) is an optional unobserved confounder.
- Outcome Y depends on treatment and confounders, with the link determined by outcome_type (T denotes the binary treatment D):
  - outcome_type = "continuous": Y = alpha_y + f_y(X) + u_strength_y * U + T * tau(X) + ε, ε ~ N(0, sigma_y^2)
  - outcome_type = "binary": logit P(Y=1|T,X,U) = alpha_y + f_y(X) + u_strength_y * U + T * tau(X)
  - outcome_type = "poisson": log E[Y|T,X,U] = alpha_y + f_y(X) + u_strength_y * U + T * tau(X)
  - outcome_type = "gamma": log E[Y|T,X,U] = alpha_y + f_y(X) + u_strength_y * U + T * tau(X)

  where f_y(X) = X @ beta_y + g_y(X), and tau(X) is either the constant theta or a user function.
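The data model can be illustrated with a minimal NumPy sketch of the continuous case. This is not the class's implementation: the function name simulate_continuous is hypothetical, confounders are assumed to be independent N(0,1), and the nonlinear terms g_d and g_y are assumed to be zero.

```python
import numpy as np


def sigmoid(z):
    """Logistic function mapping a score to a probability."""
    return 1.0 / (1.0 + np.exp(-z))


def simulate_continuous(n, beta_d, beta_y, theta=1.0, alpha_d=0.0, alpha_y=0.0,
                        u_strength_d=0.0, u_strength_y=0.0,
                        propensity_sharpness=1.0, sigma_y=1.0, seed=0):
    """Hypothetical helper: draw (X, D, Y, m_obs) for outcome_type="continuous".

    Assumes independent N(0,1) confounders and g_d = g_y = 0.
    """
    rng = np.random.default_rng(seed)
    beta_d = np.asarray(beta_d, dtype=float)
    beta_y = np.asarray(beta_y, dtype=float)
    k = beta_d.shape[0]

    X = rng.normal(size=(n, k))   # observed confounders
    U = rng.normal(size=n)        # unobserved confounder, U ~ N(0,1)

    # Treatment: D ~ Bernoulli(sigmoid(alpha_d + f_d(X) + u_strength_d * U))
    f_d = (X @ beta_d) * propensity_sharpness
    m_obs = sigmoid(alpha_d + f_d + u_strength_d * U)
    D = rng.binomial(1, m_obs)

    # Outcome: Y = alpha_y + f_y(X) + u_strength_y * U + D * theta + eps
    eps = rng.normal(scale=sigma_y, size=n)
    Y = alpha_y + X @ beta_y + u_strength_y * U + D * theta + eps
    return X, D, Y, m_obs


X, D, Y, m_obs = simulate_continuous(1000, beta_d=[0.5, -0.3], beta_y=[1.0, 0.2], theta=2.0)
```

Setting u_strength_d and u_strength_y to nonzero values in this sketch makes U a common cause of D and Y that an analyst who sees only X cannot adjust for.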
Returned columns
- y: outcome
- d: binary treatment (0/1)
- x1..xk (or user-provided names)
- m: true propensity P(T=1 | X) marginalized over U
- m_obs: realized propensity P(T=1 | X, U)
- tau_link: tau(X) on the structural (link) scale
- g0: E[Y | X, T=0] on the natural outcome scale marginalized over U
- g1: E[Y | X, T=1] on the natural outcome scale marginalized over U
- cate: g1 - g0 (conditional average treatment effect on the natural outcome scale)
Notes on effect scale:
- For "continuous", theta (or tau(X)) is an additive mean difference, so tau_link == cate.
- For "binary", tau acts on the log-odds scale. cate is reported as a risk difference.
- For "poisson" and "gamma", tau acts on the log-mean scale. cate is reported on the mean scale.
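The relationship between the link scale and the natural scale can be sketched as follows. The helper cate_from_link is hypothetical and, for simplicity, assumes no unobserved confounding in the outcome (u_strength_y = 0), so no marginalization over U is needed; baseline_link stands for alpha_y + f_y(X).

```python
import numpy as np


def sigmoid(z):
    """Inverse of the logit link."""
    return 1.0 / (1.0 + np.exp(-z))


def cate_from_link(baseline_link, tau_link, outcome_type):
    """Hypothetical helper: convert a link-scale effect to the natural scale.

    Assumes u_strength_y = 0, so g1 and g0 follow directly from the link.
    """
    b = np.asarray(baseline_link, dtype=float)
    t = np.asarray(tau_link, dtype=float)
    if outcome_type == "continuous":
        # Identity link: the effect is already a mean difference.
        return t + 0.0 * b
    if outcome_type == "binary":
        # Logit link: cate is a risk difference.
        return sigmoid(b + t) - sigmoid(b)
    if outcome_type in ("poisson", "gamma"):
        # Log link: cate is a difference of means.
        return np.exp(b + t) - np.exp(b)
    raise ValueError(f"unsupported outcome_type: {outcome_type!r}")


# A log-odds effect of log(3) at a baseline risk of 0.5 is a risk difference of 0.25.
rd = cate_from_link(0.0, np.log(3.0), "binary")
```

For the non-identity links the same tau_link yields a different cate at each baseline level, which is why the two columns are reported separately.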
Parameters
- theta (float) – Constant treatment effect used if tau is None.
- tau (callable) – Function tau(X) -> array-like shape (n,) for heterogeneous effects.
- beta_y (array-like) – Linear coefficients of confounders in the outcome baseline f_y(X).
- beta_d (array-like) – Linear coefficients of confounders in the treatment score f_d(X).
- g_y (callable) – Nonlinear/additive function g_y(X) -> (n,) added to the outcome baseline.
- g_d (callable) – Nonlinear/additive function g_d(X) -> (n,) added to the treatment score.
- alpha_y (float) – Outcome intercept (natural scale for continuous; log-odds for binary; log-mean for Poisson/Gamma).
- alpha_d (float) – Treatment intercept (log-odds). If target_d_rate is set, alpha_d is auto-calibrated.
- sigma_y (float) – Std. dev. of the Gaussian noise for continuous outcomes.
- outcome_type (('continuous', 'binary', 'poisson', 'gamma', 'tweedie')) – Outcome family and link as defined above.
- confounder_specs (list of dict) – Schema for generating confounders. See _gaussian_copula for details.
- k (int) – Number of confounders when confounder_specs is None. Defaults to independent N(0,1).
- x_sampler (callable) – Custom sampler (n, k, seed) -> X ndarray of shape (n,k). Overrides confounder_specs.
- use_copula (bool) – If True and confounder_specs provided, use Gaussian copula for X.
- copula_corr (array-like) – Correlation matrix for copula.
- target_d_rate (float) – Target treatment prevalence (propensity mean). Calibrates alpha_d.
- u_strength_d (float) – Strength of the unobserved confounder U in treatment assignment.
- u_strength_y (float) – Strength of the unobserved confounder U in the outcome.
- propensity_sharpness (float) – Scales the X-driven treatment score to adjust positivity difficulty.
- seed (int) – Random seed for reproducibility.
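The source does not say how alpha_d is calibrated when target_d_rate is set. One plausible approach, shown here purely as an assumption, is bisection on the intercept: the mean propensity is monotonically increasing in alpha_d, so a bracketing search converges. The function name calibrate_alpha_d and its arguments are hypothetical.

```python
import numpy as np


def sigmoid(z):
    """Logistic function mapping a score to a probability."""
    return 1.0 / (1.0 + np.exp(-z))


def calibrate_alpha_d(score_without_intercept, target_rate,
                      lo=-50.0, hi=50.0, tol=1e-10, max_iter=200):
    """Hypothetical helper: find alpha_d so that mean propensity hits target_rate.

    score_without_intercept is f_d(X) + u_strength_d * U for a sample.
    Uses bisection, which is an assumption, not the documented method.
    """
    s = np.asarray(score_without_intercept, dtype=float)
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        if np.mean(sigmoid(mid + s)) < target_rate:
            lo = mid   # prevalence too low: raise the intercept
        else:
            hi = mid   # prevalence too high: lower the intercept
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)


rng = np.random.default_rng(0)
score = 1.5 * rng.normal(size=5000)        # stand-in for f_d(X) + u_strength_d * U
alpha_d = calibrate_alpha_d(score, target_rate=0.2)
achieved = np.mean(sigmoid(alpha_d + score))
```

In this sketch, achieved matches the target up to the bisection tolerance on the sample used for calibration; the realized share of treated units in a fresh draw will still vary with sampling noise.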
Attributes
Functions
- generate – Draw a synthetic dataset of size n.
- oracle_nuisance – Return nuisance functions (m(x), g0(x), g1(x)) compatible with IRM.
- to_causal_data – Generate a dataset and convert it to a CausalData object.
alpha_d
alpha_y
alpha_zi
beta_d
beta_y
beta_zi
confounder_specs
copula_corr
g_d
g_y
g_zi
gamma_shape
generate
Draw a synthetic dataset of size n.
Parameters
- n (int) – Number of samples to generate.
- U (ndarray) – Unobserved confounder. If None, generated from N(0,1).
Returns
DataFrame – The generated dataset with outcome 'y', treatment 'd', confounders, and oracle ground-truth columns.
include_oracle
k
lognormal_sigma
oracle_nuisance
Return nuisance functions (m(x), g0(x), g1(x)) compatible with IRM.
Parameters
- num_quad (int) – Number of quadrature points for marginalizing over U.
Returns
dict – Dictionary of callables mapping X to nuisance values.
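Marginalizing over U ~ N(0,1) with a fixed number of quadrature points can be sketched with Gauss-Hermite quadrature. The source only says "quadrature points", so the specific rule is an assumption, and the helper marginal_propensity is hypothetical. It approximates m(x) = E_U[sigmoid(score(x) + u_strength_d * U)], where score(x) = alpha_d + f_d(x).

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def marginal_propensity(score_x, u_strength_d, num_quad=21):
    """Hypothetical helper: approximate E_U[sigmoid(score_x + u_strength_d * U)].

    Uses probabilists' Gauss-Hermite nodes (weight function exp(-u^2 / 2)),
    which is an assumption about the quadrature rule.
    """
    nodes, weights = hermegauss(num_quad)
    weights = weights / np.sqrt(2.0 * np.pi)   # normalize to the N(0,1) density
    score_x = np.atleast_1d(np.asarray(score_x, dtype=float))
    # Evaluate the propensity at every (x, node) pair, then average over nodes.
    z = score_x[:, None] + u_strength_d * nodes[None, :]
    return (1.0 / (1.0 + np.exp(-z))) @ weights


scores = np.array([-1.0, 0.0, 1.0])
m = marginal_propensity(scores, u_strength_d=0.8)
```

With u_strength_d = 0 this sketch reduces to the plain sigmoid of the score, and at a score of 0 the result is 0.5 for any u_strength_d because the integrand is symmetric around U = 0.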
outcome_type
pos_dist
propensity_sharpness
rng
score_bounding
seed
sigma_y
target_d_rate
tau
tau_zi
theta
to_causal_data
Generate a dataset and convert it to a CausalData object.
Parameters
- n (int) – Number of samples to generate.
- confounders (str or list of str) – List of confounder column names to include. If None, automatically detects numeric confounders.
Returns
CausalData – A CausalData object containing the generated dataset.