causalis.scenarios.multi_unconfoundedness.model.MultiTreatmentIRM
Bases: sklearn.base.BaseEstimator
Interactive Regression Model (IRM) for multi-treatment unconfoundedness.
DoubleML-style cross-fitting estimator that consumes a MultiCausalData container and
produces pairwise contrasts between each active treatment arm and the
baseline arm (column 0). The model supports K >= 2 mutually exclusive
treatment arms encoded as one-hot columns.
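To illustrate the expected treatment layout, the plain-NumPy sketch below turns integer arm labels into one-hot columns with column 0 as the baseline and checks that the arms are mutually exclusive. The helpers one_hot_arms and check_one_hot are hypothetical and are not part of causalis.

```python
import numpy as np


def one_hot_arms(labels, n_arms):
    """Encode integer arm labels 0..K-1 as an (n, K) one-hot matrix.

    Hypothetical helper for illustration; column 0 is the baseline arm.
    """
    labels = np.asarray(labels, dtype=int)
    if n_arms < 2:
        raise ValueError("need K >= 2 treatment arms")
    if labels.min() < 0 or labels.max() >= n_arms:
        raise ValueError("labels must lie in 0..K-1")
    D = np.zeros((labels.shape[0], n_arms), dtype=int)
    D[np.arange(labels.shape[0]), labels] = 1  # set exactly one column per row
    return D


def check_one_hot(D):
    """Return True if entries are 0/1 and every row has exactly one active arm."""
    D = np.asarray(D)
    return bool(np.isin(D, (0, 1)).all() and (D.sum(axis=1) == 1).all())


D = one_hot_arms([0, 2, 1, 0], n_arms=3)
```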
Parameters
- data : MultiCausalData
Data container with the outcome, one-hot treatment indicators, and confounders.
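- ml_g : estimator
Learner for the arm-specific outcome regression $g_k(X) = \mathbb{E}[Y \mid X, D_k = 1]$. If it is a classifier and
$Y$ is binary, predict_proba is used; otherwise predict() is used.
- ml_m : classifier
Learner for the generalized propensity score $m_k(X) = \mathbb{P}(D_k = 1 \mid X)$. Must support
predict_proba().
- n_folds : int, default 5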
Number of cross-fitting folds.
- n_rep : int, default 1
Number of repetitions of sample splitting. Currently only 1 is supported.
- normalize_ipw : bool, default False
Whether to normalize the inverse-probability terms within the score. Applies to
score="ATE" only. For score="ATTE", normalization is ignored to preserve the canonical orthogonal ATTE score used by the estimator.
- trimming_rule : {“truncate”}, default “truncate”
Trimming approach for propensity scores.
- trimming_threshold : float, default 1e-2
Lower bound applied to each class propensity before the propensities are renormalized back onto the simplex.
- random_state : Optional[int], default None
Random seed for fold creation.
- n_jobs : int, default 1
Number of parallel jobs for fold-level cross-fitting. Use
-1 to use all available CPUs. Practical guidance:
  - Start with n_jobs=1 for stable, low-contention defaults.
  - Increase to n_jobs=2, 4, or -1 when cross-fitting is the bottleneck.
  - If the nuisance learners are already multithreaded (e.g. CatBoost with thread_count=-1), keep n_jobs=1 or set the learner threads to 1 to avoid CPU oversubscription.
- store_diagnostics : bool, default True
Whether to retain raw fit-time arrays and diagnostic-only artifacts on the fitted model. Set to
False for a lighter-weight estimator that still supports effect estimation while omitting heavier caches such as confounders, raw propensities, and fold assignments.
Examples
Notes
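Let $D = (D_0, \dots, D_{K-1}) \in \{0, 1\}^K$ with $\sum_{k=0}^{K-1} D_k = 1$, where $D_k = 1$ indicates assignment to arm $k$ and arm $0$ is the designated baseline. Define the arm-specific outcome regressions and generalized propensity scores as
$$
g_k(X) = \mathbb{E}[Y \mid X, D_k = 1], \qquad m_k(X) = \mathbb{P}(D_k = 1 \mid X), \qquad k = 0, \dots, K-1.
$$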
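Under multi-arm unconfoundedness, $\{Y(0), \dots, Y(K-1)\} \perp\!\!\!\perp D \mid X$, and overlap, $m_k(X) > 0$ for every $k$,
the pairwise baseline ATE for arm $k \in \{1, \dots, K-1\}$ is
$$
\theta_k^{\mathrm{ATE}} = \mathbb{E}[Y(k) - Y(0)] = \mathbb{E}[g_k(X) - g_0(X)].
$$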
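The corresponding pairwise ATTE conditions on membership in arm $k$:
$$
\theta_k^{\mathrm{ATTE}} = \mathbb{E}[Y(k) - Y(0) \mid D_k = 1] = \mathbb{E}[Y - g_0(X) \mid D_k = 1].
$$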
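This implementation cross-fits all arm-specific outcome nuisances $g_k$ and all class propensities $m_k$. The propensity vector is lower-trimmed componentwise and then renormalized onto the simplex so that each row still sums to one:
$$
\tilde m_k(X) = \frac{\max\{m_k(X), \varepsilon\}}{\sum_{j=0}^{K-1} \max\{m_j(X), \varepsilon\}},
$$
where $\varepsilon$ is trimming_threshold.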
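A minimal sketch of the cross-fitting and trimming steps, assuming scikit-learn-style learners, follows. The names cross_fit_nuisances and trim_and_renormalize are hypothetical; the sketch fits one regression per arm on that arm's training units, maps predict_proba columns through classes_, and assumes every arm appears in every training fold. It does not reproduce the estimator's parallelism (n_jobs) or its predict_proba path for classifier outcome learners.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import KFold


def trim_and_renormalize(m, eps=1e-2):
    """Clip each propensity below at eps, then rescale each row to sum to one."""
    m = np.maximum(np.asarray(m, dtype=float), eps)
    return m / m.sum(axis=1, keepdims=True)


def cross_fit_nuisances(X, y, D, ml_g, ml_m, n_folds=5, eps=1e-2, random_state=None):
    """Out-of-fold g_hat[i, k] ~ E[Y | X, D_k = 1] and m_hat[i, k] ~ P(D_k = 1 | X)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    D = np.asarray(D)
    n, K = D.shape
    arm = D.argmax(axis=1)  # integer arm label 0..K-1 for each unit
    g_hat = np.zeros((n, K))
    m_hat = np.zeros((n, K))
    folds = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)
    for train, test in folds.split(X):
        # One outcome regression per arm, fit only on training units in that arm.
        for k in range(K):
            idx = train[arm[train] == k]
            g_hat[test, k] = clone(ml_g).fit(X[idx], y[idx]).predict(X[test])
        # Multiclass propensity model; predict_proba columns follow clf.classes_,
        # so scatter them into the matching arm columns (absent arms stay 0).
        clf = clone(ml_m).fit(X[train], arm[train])
        m_hat[test[:, None], clf.classes_[None, :]] = clf.predict_proba(X[test])
    return g_hat, trim_and_renormalize(m_hat, eps)


# Usage on simulated data: K = 3 arms, true effects 1 and 2 versus arm 0.
rng = np.random.default_rng(0)
n = 600
X = rng.normal(size=(n, 2))
arm = rng.integers(0, 3, size=n)
D = np.eye(3, dtype=int)[arm]
y = X[:, 0] + np.array([0.0, 1.0, 2.0])[arm] + rng.normal(scale=0.1, size=n)
g_hat, m_hat = cross_fit_nuisances(
    X, y, D, LinearRegression(), LogisticRegression(), random_state=0
)
```

Clipping before renormalizing keeps every row on the simplex, so the trimmed propensities can still be read as class probabilities.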
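Estimation solves the sample moment equation
$$
\frac{1}{n} \sum_{i=1}^{n} \left[ \psi_a(W_i; \hat\eta)\, \theta_k + \psi_b(W_i; \hat\eta) \right] = 0,
$$
where $W_i = (Y_i, D_i, X_i)$ and $\hat\eta$ collects the cross-fitted nuisances, which yields the closed-form estimate
$$
\hat\theta_k = -\frac{\sum_{i=1}^{n} \psi_b(W_i; \hat\eta)}{\sum_{i=1}^{n} \psi_a(W_i; \hat\eta)}.
$$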
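Assuming a linear score of the form $\psi = \psi_a \theta + \psi_b$ (the DoubleML convention; how causalis parameterizes it internally is an assumption here), solving the sample moment equation reduces to a single ratio. The helper solve_linear_moment is hypothetical.

```python
import numpy as np


def solve_linear_moment(psi_a, psi_b):
    """Solve (1/n) * sum(psi_a * theta + psi_b) = 0 for a scalar theta."""
    psi_a = np.asarray(psi_a, dtype=float)
    psi_b = np.asarray(psi_b, dtype=float)
    return -psi_b.sum() / psi_a.sum()


# With psi_a = -1 for every unit (the ATE case), theta_hat is the mean of psi_b.
psi_b = np.array([1.0, 2.0, 3.0, 6.0])
theta = solve_linear_moment(-np.ones(4), psi_b)
```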
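For the pairwise ATE, the score component for each active arm $k$ is
$$
\psi_b(W; \eta) = g_k(X) - g_0(X) + \frac{D_k\,(Y - g_k(X))}{\tilde m_k(X)} - \frac{D_0\,(Y - g_0(X))}{\tilde m_0(X)},
$$
with $\psi_a(W; \eta) = -1$. Here $k \in \{1, \dots, K-1\}$ and $\tilde m_k(X)$ denotes the trimmed-and-renormalized propensity for arm $k$.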
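The ATE score for arm k against baseline arm 0 can be sketched in NumPy as follows, assuming g_hat and m_hat are (n, K) arrays of cross-fitted nuisances with m_hat already trimmed and renormalized. The function ate_score is hypothetical, and its normalize_ipw branch assumes a Hajek-style rescaling of each weight vector to mean one, which may differ from what the estimator actually does.

```python
import numpy as np


def ate_score(y, D, g_hat, m_hat, k, normalize_ipw=False):
    """Return (psi_a, psi_b) for the pairwise ATE of arm k versus baseline arm 0."""
    y = np.asarray(y, dtype=float)
    D = np.asarray(D, dtype=float)
    w_k = D[:, k] / m_hat[:, k]  # inverse-probability weight for arm k
    w_0 = D[:, 0] / m_hat[:, 0]  # inverse-probability weight for the baseline
    if normalize_ipw:
        # Assumed Hajek-style normalization: rescale each weight vector to mean one.
        w_k = w_k / w_k.mean()
        w_0 = w_0 / w_0.mean()
    psi_b = (
        g_hat[:, k] - g_hat[:, 0]
        + w_k * (y - g_hat[:, k])
        - w_0 * (y - g_hat[:, 0])
    )
    psi_a = -np.ones_like(y)
    return psi_a, psi_b


# Toy check: K = 2, propensities 0.5, one baseline unit with a nonzero residual.
y = np.array([0.5, 1.0, 0.0, 1.0])
D = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
g_hat = np.array([[0.0, 1.0]] * 4)
m_hat = np.full((4, 2), 0.5)
psi_a, psi_b = ate_score(y, D, g_hat, m_hat, k=1)
theta_hat = -psi_b.sum() / psi_a.sum()
```

The residual terms correct the regression contrast $g_k - g_0$ for the units actually observed in each arm, which is what makes the score doubly robust.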
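For the pairwise ATTE, let $p_k = \mathbb{P}(D_k = 1)$. Because $Y(k) = Y$ is observed for treated units in arm $k$, the orthogonal (doubly robust) score takes the baseline-regression form
$$
\psi_b(W; \eta) = \frac{D_k\,(Y - g_0(X))}{p_k} - \frac{\tilde m_k(X)}{\tilde m_0(X)} \cdot \frac{D_0\,(Y - g_0(X))}{p_k}.
$$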
For ATTE, $\psi_a(W; \eta) = -D_k / p_k$ in the solved moment equation, and the returned estimate object keeps the same shape and fields as for ATE.
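A matching NumPy sketch of the ATTE score, under the same array assumptions, is below. The function atte_score is hypothetical, and it assumes $p_k$ is estimated by the sample share of arm k. Units in arms other than 0 and k receive zero weight, and with that choice of $p_k$ the sum of psi_a is exactly -n, so the estimate reduces to the mean of psi_b.

```python
import numpy as np


def atte_score(y, D, g_hat, m_hat, k):
    """Return (psi_a, psi_b) for the pairwise ATTE of arm k versus baseline arm 0.

    Only the baseline regression g_hat[:, 0] enters; m_hat is trimmed/renormalized.
    """
    y = np.asarray(y, dtype=float)
    D = np.asarray(D, dtype=float)
    p_k = D[:, k].mean()  # sample share of arm k, assumed estimate of P(D_k = 1)
    resid_0 = y - g_hat[:, 0]
    odds = m_hat[:, k] / m_hat[:, 0]  # reweights baseline units toward arm k
    psi_b = (D[:, k] * resid_0 - odds * D[:, 0] * resid_0) / p_k
    psi_a = -D[:, k] / p_k
    return psi_a, psi_b


# Toy check: K = 3; the unit in arm 2 is ignored for the arm-1 contrast.
y = np.array([0.0, 2.0, 5.0, 4.0])
D = np.eye(3)[[0, 1, 2, 1]]
g_hat = np.zeros((4, 3))
m_hat = np.full((4, 3), 1 / 3)
psi_a, psi_b = atte_score(y, D, g_hat, m_hat, k=1)
theta_hat = -psi_b.sum() / psi_a.sum()
```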