
Multi Unconfoundedness

Automated conversion of multi_unconfoundedness.ipynb


Multi Unconfoundedness is an observational (quasi-experimental) causal inference scenario with multiple treatments. Treatments are not randomly assigned, so we need to control for confounders to estimate causal effects.

Data

In our study we need to estimate the effects of different gamification mechanics. The outcome is the number of sessions per month.

Treatments:

  • d_0: No mechanics used
  • d_1: Used first set of mechanics
  • d_2: Used second set of mechanics

The data-generating process (DGP) comes from Causalis. Read more at https://causalis.causalcraft.com/articles/generate_multitreatment_gamma_26

Result
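
| | y | d_0 | d_1 | d_2 | tenure_months | avg_sessions_week | spend_last_month | premium_user | urban_resident | support_tickets_q | ... | m_obs_d_1 | tau_link_d_1 | m_d_2 | m_obs_d_2 | tau_link_d_2 | g_d_0 | g_d_1 | g_d_2 | cate_d_1 | cate_d_2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.422769 | 1.0 | 0.0 | 0.0 | 27.656605 | 3.198667 | 89.609464 | 0.0 | 1.0 | 0.0 | ... | 0.246687 | -0.352005 | 0.220781 | 0.220781 | 0.494166 | 3.279384 | 2.306314 | 5.375338 | -0.973070 | 2.095954 |
| 1 | 7.566231 | 1.0 | 0.0 | 0.0 | 23.798386 | 3.362415 | 102.337236 | 0.0 | 0.0 | 3.0 | ... | 0.179393 | -0.307360 | 0.236958 | 0.236958 | 0.420278 | 2.807850 | 2.064853 | 4.274630 | -0.742997 | 1.466780 |
| 2 | 1.702662 | 0.0 | 0.0 | 1.0 | 28.425009 | 3.391819 | 102.660712 | 0.0 | 1.0 | 1.0 | ... | 0.210566 | -0.320189 | 0.218245 | 0.218245 | 0.502415 | 3.069919 | 2.228798 | 5.073677 | -0.841121 | 2.003758 |
| 3 | 1.827530 | 1.0 | 0.0 | 0.0 | 18.860066 | 4.071175 | 83.593417 | 0.0 | 0.0 | 2.0 | ... | 0.176729 | -0.316241 | 0.237639 | 0.237639 | 0.441677 | 2.716805 | 1.980234 | 4.225485 | -0.736571 | 1.508680 |
| 4 | 1.429843 | 0.0 | 1.0 | 0.0 | 17.853087 | 3.140075 | 79.209870 | 0.0 | 1.0 | 1.0 | ... | 0.232492 | -0.350130 | 0.247027 | 0.247027 | 0.493624 | 3.224354 | 2.271869 | 5.282273 | -0.952485 | 2.057919 |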

5 rows × 26 columns

Result

Ground truth ATE for d_1 vs d_0 is -1.1950325692907122

Ground truth ATE for d_2 vs d_0 is 2.530398527003894

Result

MultiCausalData(df=(100000, 12), treatment_names=['d_0', 'd_1', 'd_2'], control_treatment='d_0', outcome='y', confounders=['tenure_months', 'avg_sessions_week', 'spend_last_month', 'premium_user', 'urban_resident', 'support_tickets_q', 'discount_eligible', 'credit_utilization'], user_id=None)

EDA

Result
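
| | treatment | count | mean | std | min | p10 | p25 | median | p75 | p90 | max |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | d_0 | 50115 | 3.758417 | 3.106725 | 0.015427 | 0.887906 | 1.626326 | 2.937863 | 4.957415 | 7.577785 | 50.239323 |
| 1 | d_2 | 25008 | 6.541717 | 5.539708 | 0.043125 | 1.512610 | 2.775637 | 5.102611 | 8.584913 | 13.348761 | 79.125235 |
| 2 | d_1 | 24877 | 2.980817 | 2.412763 | 0.009022 | 0.711997 | 1.306774 | 2.352234 | 3.946463 | 5.985070 | 25.169272 |
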

Result
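
| | treatment | n | outlier_count | outlier_rate | lower_bound | upper_bound | has_outliers | method | tail |
|---|---|---|---|---|---|---|---|---|---|
| 0 | d_0 | 50115 | 2288 | 0.045655 | -3.370308 | 9.954048 | True | iqr | both |
| 1 | d_2 | 25008 | 1173 | 0.046905 | -5.938277 | 17.298826 | True | iqr | both |
| 2 | d_1 | 24877 | 1067 | 0.042891 | -2.652760 | 7.905997 | True | iqr | both |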

As we can see, the outcome distribution is heavy-tailed: by the 1.5×IQR rule, about 4.3% to 4.7% of users in each arm are flagged as outliers. We won't trim them, because highly active users bring us high revenue.
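The outlier bounds in the table are consistent with Tukey's 1.5×IQR fences applied to both tails. A minimal sketch, assuming that rule (the helper names `iqr_bounds` and `outlier_rate` are hypothetical, not Causalis API), reproduces the d_0 bounds from the p25 and p75 values in the EDA summary:

```python
import numpy as np

def iqr_bounds(q1, q3, k=1.5):
    """Tukey fences [Q1 - k*IQR, Q3 + k*IQR] from the 25th and 75th percentiles."""
    iqr = q3 - q1
    return q1 - k * iqr, q3 + k * iqr

def outlier_rate(y, k=1.5):
    """Share of observations outside the fences in either tail (assumed rule)."""
    y = np.asarray(y, dtype=float)
    q1, q3 = np.percentile(y, [25, 75])
    lo, hi = iqr_bounds(q1, q3, k)
    return float(np.mean((y < lo) | (y > hi)))

# d_0 arm: p25 = 1.626326, p75 = 4.957415 (from the EDA summary table)
lo, hi = iqr_bounds(1.626326, 4.957415)
print(f"{lo:.4f} {hi:.4f}")  # -3.3703 9.9540
```

Because sessions are non-negative, the lower fence is never binding here, so every flagged outlier sits in the right tail.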

Result

| | confounders | mean_d_0 | mean_d_1 | abs_diff | smd | ks_pvalue |
|---|---|---|---|---|---|---|
| 0 | avg_sessions_week | 4.827957 | 5.330050 | 0.502092 | 0.253187 | 0.00000 |
| 1 | premium_user | 0.217440 | 0.296861 | 0.079421 | 0.182466 | 0.00000 |
| 2 | tenure_months | 23.672462 | 25.752063 | 2.079601 | 0.176703 | 0.00000 |
| 3 | spend_last_month | 82.894719 | 96.062898 | 13.168180 | 0.149709 | 0.00000 |
| 4 | discount_eligible | 0.325870 | 0.395626 | 0.069756 | 0.145642 | 0.00000 |
| 5 | urban_resident | 0.585912 | 0.638421 | 0.052509 | 0.107919 | 0.00000 |
| 6 | support_tickets_q | 1.478140 | 1.492302 | 0.014162 | 0.011558 | 0.47358 |
| 7 | credit_utilization | 0.449627 | 0.448996 | 0.000632 | -0.005811 | 0.86692 |

Result

| | confounders | mean_d_0 | mean_d_1 | abs_diff | smd | ks_pvalue |
|---|---|---|---|---|---|---|
| 0 | premium_user | 0.217440 | 0.274072 | 0.056632 | 0.131823 | 0.00000 |
| 1 | avg_sessions_week | 4.827957 | 5.059494 | 0.231536 | 0.116747 | 0.00000 |
| 2 | spend_last_month | 82.894719 | 89.334021 | 6.439302 | 0.076205 | 0.00000 |
| 3 | support_tickets_q | 1.478140 | 1.569378 | 0.091238 | 0.073883 | 0.00000 |
| 4 | discount_eligible | 0.325870 | 0.356006 | 0.030136 | 0.063605 | 0.00000 |
| 5 | urban_resident | 0.585912 | 0.604687 | 0.018774 | 0.038256 | 0.00002 |
| 6 | tenure_months | 23.672462 | 23.391337 | 0.281125 | -0.024131 | 0.00373 |
| 7 | credit_utilization | 0.449627 | 0.451855 | 0.002228 | 0.020493 | 0.02836 |

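So the raw arms are strongly imbalanced on the confounders: in the first table avg_sessions_week, premium_user and tenure_months have SMDs of roughly 0.18 to 0.25, and most KS tests reject equal distributions at the 5% level. A naive difference in means is biased accordingly: for d_1 vs d_0 it is 2.9808 - 3.7584 = -0.7776, while the ground-truth ATE is -1.195.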
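The `smd` column appears to use the pooled standard deviation $\sqrt{(s_0^2+s_1^2)/2}$: plugging the premium_user means from the first table into that formula reproduces the reported 0.182466. A sketch under that assumption (the function names are illustrative, not Causalis API):

```python
import numpy as np

def smd(x_control, x_treated):
    """Standardized mean difference with pooled SD sqrt((var_c + var_t) / 2) (assumed form)."""
    xc = np.asarray(x_control, dtype=float)
    xt = np.asarray(x_treated, dtype=float)
    pooled_sd = np.sqrt((xc.var(ddof=1) + xt.var(ddof=1)) / 2.0)
    return (xt.mean() - xc.mean()) / pooled_sd

def smd_binary(p_control, p_treated):
    """Same formula for a 0/1 covariate, using Bernoulli variances p * (1 - p)."""
    pooled_sd = np.sqrt((p_control * (1 - p_control) + p_treated * (1 - p_treated)) / 2.0)
    return (p_treated - p_control) / pooled_sd

# premium_user means from the first balance table
print(round(smd_binary(0.217440, 0.296861), 4))  # 0.1825
```

A common rule of thumb treats |SMD| above 0.1 as meaningful imbalance, which six of the eight covariates in the first table exceed.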

Inference

Explanation of MultiTreatmentIRM

0) Assumptions

  • SUTVA / consistency: no interference, no hidden versions of treatment, and the observed outcome equals the potential outcome under the realized arm.
  • Multi-arm unconfoundedness:
$$\left(Y(0),Y(1),\dots,Y(K-1)\right) \perp D \mid X$$

where $D$ is the one-hot treatment vector and $D_0$ is the baseline arm.

  • Positivity / overlap:
$$\Pr(D_k=1\mid X=x)>0,\quad \forall k=0,\dots,K-1.$$

In practice, the implementation enforces stability with propensity trimming.

1) Data and estimand

For each unit $i$ we observe $Y_i$, $X_i$ and $D_i$, with one-hot $D_i=(D_{i0},\dots,D_{i,K-1})$ and $\sum_k D_{ik}=1$.

MultiTreatmentIRM estimates a vector of contrasts against the baseline:

$$\theta_k = \mathbb{E}[Y(k)-Y(0)],\quad k=1,\dots,K-1.$$

So the outputs are pairwise ATEs against the baseline, here d_1 vs d_0 and d_2 vs d_0.

2) Nuisance functions

For each arm $k$:

$$g_k(x)=\mathbb{E}[Y\mid X=x, D_k=1],\qquad m_k(x)=\Pr(D_k=1\mid X=x),\quad \sum_{k=0}^{K-1}m_k(x)=1.$$

These are estimated out of fold by cross-fitting, which reduces overfitting bias in the final moment conditions.

3) Cross-fitting logic

With folds $I_1,\dots,I_F$:

  1. Train a multiclass propensity model on $I_f^c$ and predict $\hat m_k$ on $I_f$.
  2. For each arm $k$, train an outcome model on the rows of $I_f^c$ where $D_k=1$, and predict $\hat g_k$ on $I_f$.
  3. Repeat for all folds and stitch the out-of-fold predictions together.
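The three cross-fitting steps can be sketched with scikit-learn. This is an illustration under simple assumptions (logistic regression for the multiclass propensity, one linear regression per arm, integer arm labels 0..K-1 with 0 as the baseline); the learners and the name `cross_fit_nuisances` are hypothetical and not the MultiTreatmentIRM implementation.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import KFold

def cross_fit_nuisances(X, y, d, n_folds=5, seed=0):
    """Out-of-fold m_hat[i, k] ~ P(D_k=1 | X_i) and g_hat[i, k] ~ E[Y | X_i, D_k=1]."""
    X, y, d = np.asarray(X, float), np.asarray(y, float), np.asarray(d, int)
    n, K = len(y), int(d.max()) + 1
    m_hat, g_hat = np.zeros((n, K)), np.zeros((n, K))
    for train, test in KFold(n_folds, shuffle=True, random_state=seed).split(X):
        # Step 1: multiclass propensity model fitted on I_f^c, predicted on I_f
        # (hypothetical learner choice: multinomial logistic regression)
        clf = LogisticRegression(max_iter=1000).fit(X[train], d[train])
        m_hat[np.ix_(test, clf.classes_)] = clf.predict_proba(X[test])
        # Step 2: one outcome model per arm, fitted on rows of I_f^c with D_k = 1
        for k in range(K):
            rows = train[d[train] == k]
            g_hat[test, k] = LinearRegression().fit(X[rows], y[rows]).predict(X[test])
    # Step 3: every row now holds predictions from models that never saw it
    return m_hat, g_hat

# Toy data: three arms, confounded assignment, true effects -1.0 (arm 1) and 2.5 (arm 2)
rng = np.random.default_rng(0)
n = 3000
X = rng.normal(size=(n, 2))
logits = np.column_stack([np.zeros(n), 0.5 * X[:, 0], -0.5 * X[:, 1]])
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
d = np.array([rng.choice(3, p=p) for p in probs])
y = X[:, 0] + np.array([0.0, -1.0, 2.5])[d] + rng.normal(size=n)

m_hat, g_hat = cross_fit_nuisances(X, y, d)
print(m_hat.shape, (g_hat[:, 1:] - g_hat[:, [0]]).mean(axis=0))
```

Fitting each outcome model only on its own arm mirrors step 2; in practice any flexible regressor or classifier can be swapped in.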

4) Multiclass trimming

Predicted propensities are stabilized by lower-bound trimming and row renormalization:

$$\tilde m_{ik}=\max(\hat m_{ik},\varepsilon),\qquad \hat m^{trim}_{ik}=\frac{\tilde m_{ik}}{\sum_{j=0}^{K-1}\tilde m_{ij}}.$$

This keeps each row on the probability simplex and avoids exploding IPW weights.
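A minimal sketch of this trimming step in NumPy. The floor `eps=0.01` is an illustrative assumption; the text does not state the value MultiTreatmentIRM uses, and `trim_propensities` is a hypothetical name.

```python
import numpy as np

def trim_propensities(m_hat, eps=0.01):
    """Lower-bound trimming followed by row renormalization onto the simplex."""
    m_tilde = np.maximum(np.asarray(m_hat, dtype=float), eps)  # tilde m_ik = max(hat m_ik, eps)
    return m_tilde / m_tilde.sum(axis=1, keepdims=True)         # divide by the row sum

m_hat = np.array([[0.995, 0.004, 0.001],
                  [0.500, 0.300, 0.200]])
m_trim = trim_propensities(m_hat)
print(m_trim.round(4))
```

The first row becomes roughly (0.9803, 0.0099, 0.0099): after renormalization a trimmed entry can sit slightly below ε, so ε acts as a soft floor. The second row is unchanged because none of its entries is below ε.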

5) Orthogonal score for each contrast

Define residuals and IPW representers:

$$u_{ik}=Y_i-\hat g_k(X_i),\qquad h_{ik}=\frac{D_{ik}}{\hat m^{trim}_{ik}}.$$

(If normalize_ipw=True, $h_{ik}$ is column-normalized in Hajek style.)

For each active arm $k>0$ vs baseline $0$:

$$\psi_{b,ik}=\big(\hat g_k(X_i)-\hat g_0(X_i)\big)+u_{ik}h_{ik}-u_{i0}h_{i0}, \qquad \psi_a=-1.$$

Moment condition:

$$\mathbb{E}_n[\psi_a\theta_k+\psi_{b,\cdot k}]=0 \ \Rightarrow\ \hat\theta_k=\mathbb{E}_n[\psi_{b,\cdot k}].$$
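The score and moment condition above can be sketched in a few lines of NumPy. This is an illustration, not the library code: `orthogonal_contrasts` is a hypothetical name, arm 0 is assumed to be the baseline, and the Hajek branch assumes "column-normalized" means rescaling each column of $h$ to average 1.

```python
import numpy as np

def orthogonal_contrasts(y, d, m_trim, g_hat, normalize_ipw=False):
    """Scores psi_b[i, k-1] for each arm k > 0 vs baseline arm 0, and theta_hat.

    y: (n,) outcomes, d: (n,) integer arms 0..K-1, m_trim and g_hat: (n, K).
    """
    y = np.asarray(y, dtype=float)
    g_hat = np.asarray(g_hat, dtype=float)
    n, K = g_hat.shape
    D = np.eye(K)[np.asarray(d, dtype=int)]  # one-hot D_ik
    u = y[:, None] - g_hat                   # u_ik = Y_i - g_k(X_i)
    h = D / np.asarray(m_trim, dtype=float)  # h_ik = D_ik / m_ik^trim
    if normalize_ipw:
        h = h / h.mean(axis=0, keepdims=True)  # assumed Hajek form: each column averages to 1
    psi_b = (g_hat[:, 1:] - g_hat[:, [0]]) + u[:, 1:] * h[:, 1:] - u[:, [0]] * h[:, [0]]
    theta_hat = psi_b.mean(axis=0)           # psi_a = -1, so theta_k = E_n[psi_b]
    return psi_b, theta_hat

# Toy check: if every observed Y equals its own arm's g_hat, the IPW corrections vanish
g_hat = np.array([[1.0, 0.5, 3.0], [2.0, 1.0, 4.0], [1.5, 0.0, 3.5]])
d = np.array([0, 1, 2])
y = g_hat[np.arange(3), d]
m_trim = np.full((3, 3), 1 / 3)
psi_b, theta_hat = orthogonal_contrasts(y, d, m_trim, g_hat)
print(theta_hat)  # mean of g_k - g_0, i.e. -1.0 and 2.0
```

The residual terms only correct $\hat g_k - \hat g_0$ where the outcome model misses; this is what makes the score insensitive to small nuisance errors.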

6) Inference

Influence function per contrast:

$$\widehat{IF}_{ik}=\psi_{b,ik}-\hat\theta_k.$$

Then

$$\widehat{SE}(\hat\theta_k)=\sqrt{\frac{\mathrm{Var}_n(\widehat{IF}_{\cdot k})}{n}},$$

with the Wald CI

$$\hat\theta_k \pm z_{1-\alpha/2}\,\widehat{SE}(\hat\theta_k).$$

P-values use the normal approximation; the significance flag is Bonferroni-adjusted across the $K-1$ contrasts.
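A sketch of this inference step using only NumPy and the standard library. Feeding it the `theta` and `se` values printed in the sensitivity output reproduces the ci_abs intervals in the results table, (-1.2233, -1.1431) and (2.4565, 2.6031). Function names are hypothetical, and `Var_n` is taken as the population variance (ddof=0), which is an assumption.

```python
import math
import numpy as np
from statistics import NormalDist

def se_from_scores(psi_b):
    """SE(theta_k) = sqrt(Var_n(IF_k) / n) with IF_ik = psi_b[i, k] - theta_k."""
    psi_b = np.asarray(psi_b, dtype=float)
    influence = psi_b - psi_b.mean(axis=0)
    return np.sqrt(influence.var(axis=0) / psi_b.shape[0])

def wald_inference(theta, se, alpha=0.05):
    """Wald CIs, two-sided normal p-values and a Bonferroni significance flag."""
    theta, se = np.asarray(theta, dtype=float), np.asarray(se, dtype=float)
    z = NormalDist().inv_cdf(1 - alpha / 2)
    ci = np.column_stack([theta - z * se, theta + z * se])
    # two-sided p-value 2 * (1 - Phi(|t|)) written as erfc(|t| / sqrt(2))
    p = np.array([math.erfc(abs(t) / math.sqrt(2)) for t in theta / se])
    significant = p < alpha / len(theta)  # Bonferroni over the K-1 contrasts
    return ci, p, significant

# theta and se for d_1 vs d_0 and d_2 vs d_0, as printed in the sensitivity output
ci, p, sig = wald_inference([-1.18321853, 2.52979027], [0.02044409, 0.03741762])
print(ci.round(4))
print(sig)
```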

7) Relative effect reported by the model

The baseline mean is estimated via an orthogonal signal:

$$\psi_{\mu_c,i}=\hat g_0(X_i)+u_{i0}h_{i0},\qquad \hat\mu_c=\mathbb{E}_n[\psi_{\mu_c}].$$

Relative effect (%):

$$\hat\tau_k^{rel}=100\cdot\frac{\hat\theta_k}{\hat\mu_c},$$

with CI from delta-method variance (as implemented in model.py).
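A sketch of the relative effect with a delta-method standard error, assuming the textbook linearization of a ratio, $IF^{rel}_{ik}=100\,(\widehat{IF}_{ik}/\hat\mu_c-\hat\theta_k\,IF^{\mu}_i/\hat\mu_c^2)$. The exact variant in model.py may differ, and `relative_effect` is a hypothetical name. As a sanity check, the results table implies $\hat\mu_c \approx 1.1832/0.299638 \approx 3.949$ for both contrasts, above the raw d_0 mean of 3.7584; this fits $\hat\mu_c$ estimating $\mathbb{E}[Y(0)]$ over all users rather than the mean of the less active d_0 group.

```python
import numpy as np

Z_975 = 1.959963984540054  # z_{1 - alpha/2} for alpha = 0.05

def relative_effect(psi_b_k, psi_mu, z=Z_975):
    """100 * theta_k / mu_c with a delta-method SE (textbook ratio form, assumed)."""
    psi_b_k = np.asarray(psi_b_k, dtype=float)
    psi_mu = np.asarray(psi_mu, dtype=float)
    n = len(psi_b_k)
    theta, mu_c = psi_b_k.mean(), psi_mu.mean()
    if_theta, if_mu = psi_b_k - theta, psi_mu - mu_c  # influence functions
    # Linearize theta / mu_c: d(theta) / mu_c - theta * d(mu_c) / mu_c**2
    if_rel = 100.0 * (if_theta / mu_c - theta * if_mu / mu_c**2)
    rel = 100.0 * theta / mu_c
    se_rel = np.sqrt(if_rel.var() / n)
    return rel, se_rel, (rel - z * se_rel, rel + z * se_rel)

# Toy scores: a constant baseline signal, so only theta's uncertainty enters
rel, se_rel, ci_rel = relative_effect([-1.0, -1.4, -1.2, -1.2], [4.0, 4.0, 4.0, 4.0])
print(rel, se_rel)  # roughly -30.0 and 1.7678
```

Because $\hat\theta_k$ and $\hat\mu_c$ share the baseline residual term $u_{i0}h_{i0}$, their influence functions are correlated, which is one reason the relative CI need not equal the absolute CI divided by $\hat\mu_c$.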

Result

| field | d_1 vs d_0 | d_2 vs d_0 |
|---|---|---|
| estimand | ATE | ATE |
| model | MultiTreatmentIRM | MultiTreatmentIRM |
| value | -1.1832 (ci_abs: -1.2233, -1.1431) | 2.5298 (ci_abs: 2.4565, 2.6031) |
| value_relative | -29.9638 (ci_rel: -30.8175, -29.1101) | 64.0643 (ci_rel: 61.9680, 66.1606) |
| alpha | 0.0500 | 0.0500 |
| p_value | 0.0000 | 0.0000 |
| is_significant | True | True |
| n_treated | 24877 | 25008 |
| n_control | 50115 | 50115 |
| treatment_mean | 2.9808 | 6.5417 |
| control_mean | 3.7584 | 3.7584 |
| time | 2026-02-22 | 2026-02-22 |

Refutation

Unconfoundedness

Result

| | comparison | metric | value | flag |
|---|---|---|---|---|
| 0 | d_0 vs d_1 | balance_max_smd | 0.013307 | GREEN |
| 1 | d_0 vs d_1 | balance_frac_violations | 0.0 | GREEN |
| 2 | d_0 vs d_1 | balance_pass | True | GREEN |
| 3 | d_0 vs d_2 | balance_max_smd | 0.004068 | GREEN |
| 4 | d_0 vs d_2 | balance_frac_violations | 0.0 | GREEN |
| 5 | d_0 vs d_2 | balance_pass | True | GREEN |
| 6 | overall | balance_max_smd | 0.013307 | GREEN |
| 7 | overall | balance_frac_violations | 0.0 | GREEN |
| 8 | overall | balance_pass | True | GREEN |

Sensitivity

Result
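
| | cf_y | r2_y | r2_d | rho | theta_long | theta_short | delta |
|---|---|---|---|---|---|---|---|
| d_1 vs d_0 | 3.022036e-07 | 3.022035e-07 | 1.351745e-07 | -1.0 | -1.183219 | -1.146914 | -0.036304 |
| d_2 vs d_0 | 3.022036e-07 | 3.022035e-07 | 1.461506e-07 | -1.0 | 2.529790 | 2.470708 | 0.059082 |
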
Result

{'theta': array([-1.18321853, 2.52979027]), 'se': array([0.02044409, 0.03741762]), 'alpha': 0.05, 'z': 1.959963984540054, 'H0': 0.0, 'sampling_ci': array([[-1.22328822, -1.14314885], [ 2.45645308, 2.60312746]]), 'theta_bounds_cofounding': array([[-1.61866521, -0.74777186], [ 2.1016666 , 2.95791394]]), 'bias_aware_ci': array([[-1.65985899, -0.70803275], [ 2.02990876, 3.03312867]]), 'max_bias_base': array([8.48841832, 8.34566666]), 'max_bias': array([0.43544667, 0.42812367]), 'bound_width': array([0.43544667, 0.42812367]), 'sigma2': 11.79593879649448, 'nu2': array([6.10830955, 5.90458744]), 'rv': array([0.27985187, 0.64760316]), 'rva': array([0.26617826, 0.63406235]), 'contrast_labels': ['d_1 vs d_0', 'd_2 vs d_0'], 'params': {'cf_y': 0.05, 'r2_d': array([0.05, 0.05]), 'rho': array([1., 1.]), 'use_signed_rr': False}}
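The confounding bounds in this output are consistent with an omitted-variable-bias bound of the form $|\text{bias}_k| \le |\rho_k|\sqrt{c_{f,y}\,R^2_{d,k}/(1-R^2_{d,k})}\,\sqrt{\sigma^2\nu_k^2}$, where `max_bias_base` $=\sqrt{\sigma^2\nu_k^2}$ and `theta_bounds_cofounding` $=\hat\theta_k \pm$ `max_bias`. The sketch below recomputes those fields from `sigma2`, `nu2` and `params`; the formula is inferred from the printed numbers rather than taken from the Causalis source, and `confounding_bounds` is a hypothetical name.

```python
import numpy as np

def confounding_bounds(theta, sigma2, nu2, cf_y, r2_d, rho):
    """Bounds theta +/- max_bias (formula inferred from the printed output)."""
    theta, nu2 = np.asarray(theta, dtype=float), np.asarray(nu2, dtype=float)
    r2_d, rho = np.asarray(r2_d, dtype=float), np.asarray(rho, dtype=float)
    max_bias_base = np.sqrt(sigma2 * nu2)                         # sqrt(sigma^2 * nu_k^2)
    strength = np.abs(rho) * np.sqrt(cf_y * r2_d / (1.0 - r2_d))  # confounding strength
    max_bias = strength * max_bias_base
    return max_bias_base, max_bias, np.column_stack([theta - max_bias, theta + max_bias])

base, max_bias, bounds = confounding_bounds(
    theta=[-1.18321853, 2.52979027],
    sigma2=11.79593879649448,
    nu2=[6.10830955, 5.90458744],
    cf_y=0.05, r2_d=[0.05, 0.05], rho=[1.0, 1.0],
)
print(base.round(4), max_bias.round(4))
print(bounds.round(4))
```

Under these benchmark strengths (cf_y = 0.05, r2_d = 0.05, rho = 1) both bound intervals stay on the same side of zero, so the signs of the two effects would survive that much hidden confounding.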

Overlap


Result
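
| | comparison | metric | value | flag |
|---|---|---|---|---|
| 0 | d_0 vs d_1 | edge_0.01_below | 0.0 | GREEN |
| 1 | d_0 vs d_1 | edge_0.01_above | 0.0 | GREEN |
| 2 | d_0 vs d_1 | KS | 0.141067 | GREEN |
| 3 | d_0 vs d_1 | AUC | 0.596814 | GREEN |
| 4 | d_0 vs d_1 | ESS_treated_ratio | 0.895799 | GREEN |
| 5 | d_0 vs d_1 | ESS_baseline_ratio | 0.954377 | GREEN |
| 6 | d_0 vs d_1 | clip_m_total | 0.0 | GREEN |
| 7 | d_0 vs d_1 | overlap_pass | True | GREEN |
| 8 | d_0 vs d_2 | edge_0.01_below | 0.0 | GREEN |
| 9 | d_0 vs d_2 | edge_0.01_above | 0.0 | GREEN |
| 10 | d_0 vs d_2 | KS | 0.074956 | GREEN |
| 11 | d_0 vs d_2 | AUC | 0.549007 | GREEN |
| 12 | d_0 vs d_2 | ESS_treated_ratio | 0.948692 | GREEN |
| 13 | d_0 vs d_2 | ESS_baseline_ratio | 0.954377 | GREEN |
| 14 | d_0 vs d_2 | clip_m_total | 0.0 | GREEN |
| 15 | d_0 vs d_2 | overlap_pass | True | GREEN |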

SUTVA

Result

  1. Are your clients independent, i.e. does one client's outcome not depend on the treatment of other clients?
  2. Do all clients have a full window in which the metrics are measured?
  3. Do you measure the confounders before treatment and the outcome after it?
  4. Do you have a consistent treatment label, e.g. is a person who did not receive a treatment always labeled 0?

Score

Result

| | comparison | metric | value | flag |
|---|---|---|---|---|
| 0 | d_1 vs d_0 | se_plugin | 2.044409e-02 | NA |
| 1 | d_1 vs d_0 | psi_p99_over_med | 9.656081e+00 | GREEN |
| 2 | d_1 vs d_0 | psi_kurtosis | 6.103705e+01 | RED |
| 3 | d_1 vs d_0 | max_\|t\|_gk | 5.488800e+00 | RED |
| 4 | d_1 vs d_0 | max_\|t\|_g0 | 4.107007e+00 | RED |
| 5 | d_1 vs d_0 | max_\|t\|_mk | 1.723246e+00 | RED |
| 6 | d_1 vs d_0 | max_\|t\|_m0 | 9.548751e-01 | RED |
| 7 | d_1 vs d_0 | max_\|t\| | 5.488800e+00 | RED |
| 8 | d_1 vs d_0 | oos_tstat_fold | -1.028761e-15 | GREEN |
| 9 | d_1 vs d_0 | oos_tstat_strict | -1.779456e-15 | GREEN |
| 10 | d_2 vs d_0 | se_plugin | 3.741762e-02 | NA |
| 11 | d_2 vs d_0 | psi_p99_over_med | 1.610667e+01 | YELLOW |
| 12 | d_2 vs d_0 | psi_kurtosis | 5.568245e+01 | RED |
| 13 | d_2 vs d_0 | max_\|t\|_gk | 5.333558e+00 | RED |
| 14 | d_2 vs d_0 | max_\|t\|_g0 | 4.107007e+00 | RED |
| 15 | d_2 vs d_0 | max_\|t\|_mk | 1.431120e+00 | RED |
| 16 | d_2 vs d_0 | max_\|t\|_m0 | 9.548751e-01 | RED |
| 17 | d_2 vs d_0 | max_\|t\| | 5.333558e+00 | RED |
| 18 | d_2 vs d_0 | oos_tstat_fold | -4.861275e-16 | GREEN |
| 19 | d_2 vs d_0 | oos_tstat_strict | -6.684270e-16 | GREEN |


Conclusion

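The set of mechanics labeled d_2 performed better than using no mechanics, with an estimated effect of 2.5298 (ci_abs: 2.4565, 2.6031) sessions per user per month. The set labeled d_1, however, performed worse than using no mechanics, with an effect of -1.1832 (ci_abs: -1.2233, -1.1431) sessions. Both intervals cover the ground-truth ATEs (-1.195 and 2.530). We should turn the d_1 mechanics off.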