Submodule
causalis.scenarios.multi_unconfoundedness.model

model

Submodule causalis.scenarios.multi_unconfoundedness.model with no child pages and 11 documented members.

Classes

class
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM

MultiTreatmentIRM

Bases: sklearn.base.BaseEstimator

Interactive Regression Model (IRM) for multi-treatment unconfoundedness.

DoubleML-style cross-fitting estimator consuming MultiCausalData and producing pairwise contrasts between each active treatment arm and the baseline arm (column 0). The model supports K >= 2 mutually exclusive treatment arms encoded as one-hot columns.
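The one-hot layout can be illustrated with a small NumPy sketch. The helper `one_hot_treatment` is hypothetical and not part of causalis; it only shows the encoding the description assumes: one column per arm, exactly one 1 per row, and column 0 as the baseline against which arms 1..K-1 are contrasted. How `MultiCausalData` itself builds or validates these columns is not shown here.

```python
import numpy as np


def one_hot_treatment(labels, n_arms):
    """Encode integer arm labels 0..n_arms-1 as mutually exclusive one-hot columns.

    Hypothetical helper: column 0 is the baseline arm, columns 1..n_arms-1
    are the active arms that get contrasted against it.
    """
    labels = np.asarray(labels, dtype=int)
    if n_arms < 2:
        raise ValueError("need at least two arms (K >= 2)")
    if labels.min() < 0 or labels.max() >= n_arms:
        raise ValueError("labels must lie in {0, ..., K-1}")
    D = np.zeros((labels.size, n_arms), dtype=int)
    D[np.arange(labels.size), labels] = 1  # exactly one 1 per row
    return D


D = one_hot_treatment(np.array([0, 2, 1, 0]), n_arms=3)
active_arms = list(range(1, D.shape[1]))  # one contrast per active arm vs. column 0
```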

Parameters

data : MultiCausalData

Data container with outcome, one-hot treatment indicators, and confounders.

ml_g : estimator

Learner for $\mathbb{E}[Y \mid X, D=k]$. If it is a classifier and Y is binary, predict_proba() is used; otherwise predict() is used.

ml_m : classifier

Learner for the generalized propensity score $\mathbb{P}(D=k \mid X)$. Must support predict_proba().

n_folds : int, default 5

Number of cross-fitting folds.
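To illustrate what cross-fitting over `n_folds` folds means, here is a minimal sketch under simplifying assumptions: the covariate `x` is a single small discrete variable, `d` holds integer arm labels, and both nuisances are estimated by within-cell means instead of the user-supplied `ml_g` and `ml_m` learners. The function names are hypothetical. The property that matters is that each row's predictions come only from data in the other folds.

```python
import numpy as np


def make_folds(n, n_folds=5, random_state=None):
    """Randomly assign n rows to n_folds folds of (nearly) equal size."""
    rng = np.random.default_rng(random_state)
    return rng.permutation(n) % n_folds


def cross_fit_cell_means(y, d, x, n_arms, n_folds=5, random_state=None):
    """Cross-fit g_hat[i, k] ~ E[Y | X, D=k] and m_hat[i, k] ~ P(D=k | X).

    Toy learners: both nuisances are within-cell means computed on the
    training folds only, never on the fold that contains row i.
    """
    n = y.size
    folds = make_folds(n, n_folds, random_state)
    g_hat = np.empty((n, n_arms))
    m_hat = np.empty((n, n_arms))
    for f in range(n_folds):
        train = folds != f
        for i in np.flatnonzero(folds == f):
            cell = train & (x == x[i])
            for k in range(n_arms):
                arm = train & (d == k)
                arm_cell = cell & (d == k)
                # Outcome regression: mean Y of training rows in arm k and in
                # row i's covariate cell (arm-level mean if that cell is empty).
                g_hat[i, k] = y[arm_cell].mean() if arm_cell.any() else y[arm].mean()
                # Generalized propensity: share of arm k among training rows
                # in the cell (overall training share if the cell is empty).
                if cell.any():
                    m_hat[i, k] = arm_cell.sum() / cell.sum()
                else:
                    m_hat[i, k] = arm.sum() / train.sum()
    return g_hat, m_hat, folds
```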

n_rep : int, default 1

Number of repetitions of sample splitting. Currently only 1 is supported.

normalize_ipw : bool, default False

Whether to normalize inverse-probability terms within the score. Applied to score="ATE" only. For score="ATTE", normalization is ignored to preserve the canonical orthogonal ATTE score used by the estimator.

trimming_rule : {"truncate"}, default "truncate"

Trimming approach for propensity scores.

trimming_threshold : float, default 1e-2

Lower threshold used before renormalizing multiclass propensities back to the simplex.

random_state : Optional[int], default None

Random seed for fold creation.

n_jobs : int, default 1

Number of parallel jobs for fold-level cross-fitting. Use -1 to use all available CPUs. Practical guidance:

- Start with n_jobs=1 for stable, low-contention defaults.
- Increase to n_jobs=2, 4, or -1 when cross-fitting is the bottleneck.
- If the nuisance learners are already multithreaded (e.g. CatBoost with thread_count=-1), keep n_jobs=1 or set the learners' thread count to 1 to avoid CPU oversubscription.
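The fold-level parallelism behind `n_jobs` can be sketched with the standard library. This is an illustration under stated assumptions, not the estimator's code: it uses a thread pool for simplicity, whereas the library's actual parallel backend is not documented here, and `fit_one_fold` is a hypothetical stand-in for whatever per-fold fitting is done. If each fold's learner also spawns its own threads, the total thread count is roughly `n_jobs` times the learner's thread count, which is the oversubscription the guidance above warns about.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def resolve_n_jobs(n_jobs):
    """Map n_jobs=-1 to all available CPUs; otherwise require a positive int."""
    if n_jobs == -1:
        return os.cpu_count() or 1
    if n_jobs < 1:
        raise ValueError("n_jobs must be -1 or a positive integer")
    return n_jobs


def run_folds(fit_one_fold, n_folds, n_jobs=1):
    """Call fit_one_fold(fold_id) for every fold, optionally in parallel.

    Results are returned in fold order regardless of n_jobs, so the
    assembled cross-fitted predictions do not depend on the worker count.
    """
    workers = resolve_n_jobs(n_jobs)
    if workers == 1:
        return [fit_one_fold(f) for f in range(n_folds)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(fit_one_fold, range(n_folds)))
```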

store_diagnostics : bool, default True

Whether to retain raw fit-time arrays and diagnostic-only artifacts on the fitted model. Set to False for a lighter-weight estimator that still supports effect estimation while omitting heavier caches such as confounders, raw propensities, and fold assignments.

Examples

Notes

Let $W = (Y, D, X)$, where $D \in \{0, 1, \dots, K-1\}$ and arm $0$ is the designated baseline. Define the arm-specific outcome regressions and generalized propensity scores as

$$g_{0,k}(x) = \mathbb{E}[Y \mid D=k, X=x], \qquad m_{0,k}(x) = \mathbb{P}(D=k \mid X=x).$$

Under multi-arm unconfoundedness and overlap,

$$(Y(0), \dots, Y(K-1)) \perp D \mid X, \qquad 0 < m_{0,k}(X) < 1 \ \text{a.s. for all } k,$$

the pairwise baseline ATE for arm $k > 0$ is

$$\theta_{0,k}^{ATE} = \mathbb{E}[g_{0,k}(X) - g_{0,0}(X)].$$

The corresponding pairwise ATTE conditions on membership in arm $k$:

$$\theta_{0,k}^{ATTE} = \mathbb{E}[g_{0,k}(X) - g_{0,0}(X) \mid D=k].$$

This implementation cross-fits all arm-specific outcome nuisances $\hat g_k(X)$ and all class propensities $\hat m_k(X)$. The propensity vector is lower-trimmed componentwise and then renormalized onto the simplex so that each row still sums to one.
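A minimal NumPy sketch of the trimming step, assuming `trimming_rule="truncate"` means clipping each propensity from below at `trimming_threshold` and then dividing each row by its sum. The function name is hypothetical.

```python
import numpy as np


def trim_and_renormalize(m_hat, threshold=1e-2):
    """Lower-truncate each propensity at `threshold`, then rescale rows to sum to one."""
    m = np.clip(np.asarray(m_hat, dtype=float), threshold, None)
    return m / m.sum(axis=1, keepdims=True)


m_tilde = trim_and_renormalize(np.array([[0.001, 0.499, 0.5],
                                         [0.2, 0.3, 0.5]]))
```

One consequence worth knowing: after renormalization a trimmed entry can end up slightly below the threshold (in the first row, 0.01 / 1.009 ≈ 0.0099), because clipping increased the row sum. Rows that needed no trimming are unchanged.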

Estimation solves the sample moment equation

$$\mathbb{E}_n[\psi_a(W_i; \hat\eta)\,\theta_k + \psi_{b,k}(W_i; \hat\eta)] = 0,$$

which yields the closed-form estimate

$$\hat\theta_k = -\frac{\mathbb{E}_n[\psi_{b,k}(W_i; \hat\eta)]}{\mathbb{E}_n[\psi_a(W_i; \hat\eta)]}.$$
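Because the score is linear in $\theta_k$, the moment equation can be solved in closed form. The sketch below also computes a standard error with the usual sandwich formula for linear orthogonal scores, $\hat\sigma^2 = \mathbb{E}_n[\psi^2] / \mathbb{E}_n[\psi_a]^2$ with $\psi = \psi_a \hat\theta_k + \psi_{b,k}$; that this matches the class's `se` property exactly is an assumption based on the DoubleML convention the estimator is modeled on. The function name is hypothetical.

```python
import numpy as np


def solve_linear_score(psi_a, psi_b):
    """Solve E_n[psi_a * theta + psi_b] = 0 and return (theta_hat, se)."""
    psi_a = np.asarray(psi_a, dtype=float)
    psi_b = np.asarray(psi_b, dtype=float)
    J = psi_a.mean()
    theta = -psi_b.mean() / J
    psi = psi_a * theta + psi_b          # score evaluated at theta_hat
    var = np.mean(psi ** 2) / J ** 2     # asymptotic variance of theta_hat
    return theta, np.sqrt(var / psi_b.size)


theta, se = solve_linear_score(-np.ones(4), np.array([1.0, 2.0, 3.0, 4.0]))
```

With $\psi_a = -1$, as for both scores on this page, $\hat\theta_k$ reduces to the sample mean of $\psi_{b,k}$, that is, of the orthogonal signal.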

For the pairwise ATE, the score component for each active arm $k > 0$ is

$$\psi_{b,k}^{ATE} = \hat g_k(X) - \hat g_0(X) + (Y - \hat g_k(X)) \frac{d_k}{\tilde m_k(X)} - (Y - \hat g_0(X)) \frac{d_0}{\tilde m_0(X)},$$

with $\psi_a = -1$. Here $d_k = 1\{D=k\}$ and $\tilde m_k$ denotes the trimmed-and-renormalized propensity for arm $k$.
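The ATE score for one active arm can be written directly from the formula above. This sketch takes already cross-fitted arrays; the `normalize_ipw` branch is an assumption that normalization means rescaling each inverse-probability weight $d_k / \tilde m_k$ to have sample mean one (Hajek-style), which may differ in detail from the library's implementation. The function name is hypothetical.

```python
import numpy as np


def ate_psi_b(y, d_onehot, g_hat, m_tilde, k, normalize_ipw=False):
    """psi_b^{ATE} for active arm k versus baseline arm 0 (one value per row).

    y: (n,) outcomes; d_onehot: (n, K) one-hot arms; g_hat, m_tilde: (n, K)
    cross-fitted outcome regressions and trimmed-and-renormalized propensities.
    """
    w_k = d_onehot[:, k] / m_tilde[:, k]
    w_0 = d_onehot[:, 0] / m_tilde[:, 0]
    if normalize_ipw:
        # Assumed Hajek-style normalization: rescale each weight to mean one.
        w_k = w_k / w_k.mean()
        w_0 = w_0 / w_0.mean()
    return (g_hat[:, k] - g_hat[:, 0]
            + (y - g_hat[:, k]) * w_k
            - (y - g_hat[:, 0]) * w_0)
```

Since $\psi_a = -1$, the point estimate for arm $k$ is simply the mean of the returned array.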

For the pairwise ATTE, let $p_k = \mathbb{E}[d_k]$. Because $Y(k)$ is observed for treated units in arm $k$, the orthogonal (doubly robust) score takes the baseline-regression form

$$\psi_{b,k}^{ATTE} = \frac{d_k}{p_k} (Y - \hat g_0(X)) - \frac{d_0}{p_k} \frac{\tilde m_k(X)}{\tilde m_0(X)} (Y - \hat g_0(X)).$$

For ATTE, $\psi_a = -1$ in the solved moment equation as well, and the returned estimate object has the same shape and fields as for ATE.
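The ATTE score only needs the baseline outcome regression $\hat g_0$ and the propensity odds $\tilde m_k / \tilde m_0$. A minimal sketch, assuming $p_k$ is estimated by the sample share of arm $k$; the function name is hypothetical.

```python
import numpy as np


def atte_psi_b(y, d_onehot, g_hat, m_tilde, k):
    """psi_b^{ATTE} for active arm k versus baseline arm 0 (used with psi_a = -1)."""
    d_k = d_onehot[:, k]
    d_0 = d_onehot[:, 0]
    p_k = d_k.mean()                        # sample share of arm k
    resid_0 = y - g_hat[:, 0]               # residual from the baseline regression
    odds = m_tilde[:, k] / m_tilde[:, 0]    # reweights baseline units toward arm k
    return d_k / p_k * resid_0 - d_0 / p_k * odds * resid_0
```

Note that `g_hat[:, k]` is not used: treated units in arm $k$ contribute their observed outcomes directly, which is what the baseline-regression form refers to.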

method
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.fit

fit

method
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.estimate

estimate

Estimate pairwise baseline contrasts for each active treatment arm.

Parameters

score : {"ATE", "ATTE"}, default "ATE"

Target estimand. "ATE" estimates pairwise average treatment effects for each active arm versus baseline arm 0. "ATTE" estimates the corresponding pairwise average treatment effect on the treated for each active arm versus baseline arm 0, using the orthogonal (doubly robust) ATTE score.

alpha : float, default 0.05

Two-sided significance level used for Wald confidence intervals.
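For reference, a two-sided Wald interval at level `alpha` is $\hat\theta_k \pm z_{1-\alpha/2}\,\widehat{se}_k$. The sketch below, which uses only the standard library's `statistics.NormalDist`, computes that interval together with a normal-approximation p-value for $H_0: \theta_k = 0$. It illustrates the standard formulas only; whether the library applies any multiplicity adjustment across arms is not stated here, and none is applied below. The function name is hypothetical.

```python
from statistics import NormalDist


def wald_inference(theta, se, alpha=0.05):
    """Two-sided Wald confidence interval and p-value for H0: theta = 0."""
    z = NormalDist().inv_cdf(1 - alpha / 2)           # e.g. ~1.96 for alpha=0.05
    t_stat = theta / se
    p_value = 2 * (1 - NormalDist().cdf(abs(t_stat)))
    return (theta - z * se, theta + z * se), p_value
```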

diagnostic_data : bool, default True

Whether to attach the fitted diagnostic payload to the returned estimate.

Returns

MultiCausalEstimate

Result container holding one effect estimate per active arm versus the baseline arm, together with confidence intervals, p-values, relative effects, and optionally diagnostic payloads.

property
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.diagnostics_

diagnostics_

Return diagnostic data.

Returns

dict

Dictionary containing ‘m_hat’, ‘g_hat’ and ‘folds’.

property
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.coef

coef

Return the estimated coefficient.

Returns

np.ndarray

The estimated coefficient.

property
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.se

se

Return the standard error of the estimate.

Returns

np.ndarray

The standard error.

property
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.pvalues

pvalues

Return the p-values for the estimate.

Returns

np.ndarray

The p-values.

property
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.summary

summary

Return a summary DataFrame of the results.

Returns

pd.DataFrame

The results summary.

property
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.orth_signal

orth_signal

Return the cross-fitted orthogonal signal (psi_b).

Returns

np.ndarray

The orthogonal signal.

method
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.sensitivity_analysis

sensitivity_analysis

method
causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM.confint

confint
