generate_obs_hte_26_rich()

Automated conversion of generate_obs_hte_26_rich.ipynb

The generate_obs_hte_26_rich() function provides a more complex and realistic observational dataset. It features 11 confounders, complex dependencies (including derived features), a low treatment rate, and a Tweedie (zero-inflated Gamma) outcome model.

1. Confounders ($X$)

The dataset contains 11 confounders $X = (X_1, \dots, X_{11})^T$:

  • $X_1$: tenure_months
  • $X_2$: avg_sessions_week
  • $X_3$: spend_last_month
  • $X_4$: age_years
  • $X_5$: income_monthly
  • $X_6$: prior_purchases_12m
  • $X_7$: support_tickets_90d
  • $X_8$: premium_user
  • $X_9$: mobile_user
  • $X_{10}$: urban_resident
  • $X_{11}$: referred_user

Base Features Sampling: The base features $X_1, X_2, X_3, X_4, X_5, X_{10}$ are sampled using a Gaussian copula with the following correlation matrix:

$$
\Sigma = \begin{pmatrix}
1.00 & 0.20 & 0.20 & 0.30 & 0.20 & 0.00 \\
0.20 & 1.00 & 0.45 & -0.20 & 0.30 & 0.10 \\
0.20 & 0.45 & 1.00 & -0.10 & 0.40 & 0.10 \\
0.30 & -0.20 & -0.10 & 1.00 & 0.20 & -0.10 \\
0.20 & 0.30 & 0.40 & 0.20 & 1.00 & 0.10 \\
0.00 & 0.10 & 0.10 & -0.10 & 0.10 & 1.00
\end{pmatrix}
$$
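A minimal sketch of the copula step, using only the Python standard library. It draws correlated standard normals with the correlation matrix above and maps them to uniforms. The marginal distributions of the six features are not stated on this page, so the final inverse-CDF step is left out. The row/column order is assumed to follow the order in which the base features are listed, and the helper names (`cholesky`, `sample_latent_normals`, `sample_copula_uniforms`) are illustrative, not part of the library.

```python
import math
import random
from statistics import NormalDist

# Correlation matrix from the text, assumed to be ordered as
# (tenure_months, avg_sessions_week, spend_last_month,
#  age_years, income_monthly, urban_resident).
SIGMA = [
    [1.00, 0.20, 0.20, 0.30, 0.20, 0.00],
    [0.20, 1.00, 0.45, -0.20, 0.30, 0.10],
    [0.20, 0.45, 1.00, -0.10, 0.40, 0.10],
    [0.30, -0.20, -0.10, 1.00, 0.20, -0.10],
    [0.20, 0.30, 0.40, 0.20, 1.00, 0.10],
    [0.00, 0.10, 0.10, -0.10, 0.10, 1.00],
]


def cholesky(matrix):
    """Return the lower-triangular factor L with L * L^T == matrix."""
    n = len(matrix)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            partial = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                lower[i][j] = math.sqrt(matrix[i][i] - partial)
            else:
                lower[i][j] = (matrix[i][j] - partial) / lower[j][j]
    return lower


LOWER = cholesky(SIGMA)
STD_NORMAL = NormalDist()


def sample_latent_normals(rng):
    """One draw of correlated standard normals, Z ~ N(0, SIGMA)."""
    independent = [rng.gauss(0.0, 1.0) for _ in range(len(SIGMA))]
    return [
        sum(LOWER[i][k] * independent[k] for k in range(i + 1))
        for i in range(len(SIGMA))
    ]


def sample_copula_uniforms(rng):
    """Map the latent normals to correlated Uniform(0, 1) values."""
    return [STD_NORMAL.cdf(z) for z in sample_latent_normals(rng)]


rng = random.Random(0)
uniforms = sample_copula_uniforms(rng)
print(uniforms)  # six correlated uniforms; each feature's inverse CDF would be applied next
```

Each uniform would then be pushed through the inverse CDF of its feature's marginal distribution, which keeps the rank dependence encoded by the matrix while allowing arbitrary marginals.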

Derived Features: The remaining features are derived to mimic real-world behavioral dependencies:

  • Premium User ($X_8$): $\text{logit}(P(X_8=1)) = -5.0 + 0.7 \ln(1+X_2) + 0.45 \ln(1+X_3) + 0.35 \ln(1+X_5) + 0.015 X_1$
  • Mobile User ($X_9$): $\text{logit}(P(X_9=1)) = 1.5 - 0.03(X_4 - 35) + 0.30 X_{10} + 0.25 \ln(1+X_2)$
  • Referred User ($X_{11}$): $\text{logit}(P(X_{11}=1)) = -1.2 - 0.02(X_1 - 12) + 0.45 X_9 + 0.20 X_{10}$
  • Prior Purchases ($X_6$): $X_6 \sim \text{Poisson}\left(\lambda = \text{clip}\left(1.0 + 0.55 \ln(1+X_2) + 0.45 \ln(1+X_3) + 0.25 X_8,\ 0.1,\ 30.0\right)\right)$
  • Support Tickets ($X_7$): $X_7 \sim \text{Poisson}\left(\lambda = \text{clip}\left(0.6 + 0.25 \ln(1+X_2) + 0.30(1-X_8) + 0.15 \tanh\left(\frac{X_4 - 45}{12}\right),\ 0.05,\ 10.0\right)\right)$
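The derived features have to be drawn in dependency order: $X_8$ before $X_6$ and $X_7$, and $X_9$ before $X_{11}$. The sketch below transcribes the five formulas into plain Python. The function names, the dictionary layout, and the Knuth-style Poisson sampler are assumptions made for illustration; the library's own implementation may differ.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def clip(value, low, high):
    return max(low, min(high, value))


def sample_poisson(rate, rng):
    """Knuth's multiplication method; adequate for the small rates here (<= 30)."""
    threshold = math.exp(-rate)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


def premium_user_prob(tenure, sessions, spend, income):
    return sigmoid(-5.0 + 0.7 * math.log1p(sessions) + 0.45 * math.log1p(spend)
                   + 0.35 * math.log1p(income) + 0.015 * tenure)


def mobile_user_prob(age, urban, sessions):
    return sigmoid(1.5 - 0.03 * (age - 35) + 0.30 * urban + 0.25 * math.log1p(sessions))


def referred_user_prob(tenure, mobile, urban):
    return sigmoid(-1.2 - 0.02 * (tenure - 12) + 0.45 * mobile + 0.20 * urban)


def prior_purchases_rate(sessions, spend, premium):
    raw = 1.0 + 0.55 * math.log1p(sessions) + 0.45 * math.log1p(spend) + 0.25 * premium
    return clip(raw, 0.1, 30.0)


def support_tickets_rate(sessions, premium, age):
    raw = (0.6 + 0.25 * math.log1p(sessions) + 0.30 * (1 - premium)
           + 0.15 * math.tanh((age - 45) / 12))
    return clip(raw, 0.05, 10.0)


def sample_derived(base, rng):
    """Draw X8, X9, X11, X6, X7 in dependency order from the base features."""
    premium = int(rng.random() < premium_user_prob(
        base["tenure_months"], base["avg_sessions_week"],
        base["spend_last_month"], base["income_monthly"]))
    mobile = int(rng.random() < mobile_user_prob(
        base["age_years"], base["urban_resident"], base["avg_sessions_week"]))
    referred = int(rng.random() < referred_user_prob(
        base["tenure_months"], mobile, base["urban_resident"]))
    purchases = sample_poisson(prior_purchases_rate(
        base["avg_sessions_week"], base["spend_last_month"], premium), rng)
    tickets = sample_poisson(support_tickets_rate(
        base["avg_sessions_week"], premium, base["age_years"]), rng)
    return {"premium_user": premium, "mobile_user": mobile, "referred_user": referred,
            "prior_purchases_12m": purchases, "support_tickets_90d": tickets}


# Illustrative base-feature values (hypothetical unit).
base = {"tenure_months": 28.8, "avg_sessions_week": 1.0, "spend_last_month": 77.9,
        "age_years": 50.2, "income_monthly": 1926.7, "urban_resident": 1}
derived = sample_derived(base, random.Random(0))
print(derived)
```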

2. Treatment Assignment ($D$)

The treatment $D \in \{0, 1\}$ is assigned with a target rate of 5%:

$$P(D=1|X) = \sigma\left( \alpha_d + \text{bound}\left( f_d(X), 2.0 \right) \right)$$

where

$$f_d(X) = \sum_{j=1}^{11} \beta_{d,j} X_j + g_d(X)$$

with

$$\beta_d = (-0.004,\ 0,\ 0,\ -0.012,\ -0.00005,\ -0.04,\ 0.22,\ -0.45,\ -0.08,\ -0.12,\ 0.10),$$

and

$$
\begin{aligned}
g_d(X) &= -0.55\tanh\!\left(\ln\frac{1+X_3}{61}\right) -0.20\ln\frac{1+X_2}{6}\tanh\!\left(\frac{X_1}{24}-1\right) -0.25X_8(X_{10}-0.5) -0.20X_9\tanh\!\left(\ln\frac{1+X_2}{4}\right)\\
&\quad +0.30X_{11}\left(1-\tanh\frac{X_1}{36}\right) +0.35\tanh\!\left(\ln\frac{1+X_7}{2.5}\right) -0.25\tanh\!\left(\frac{\ln(1+X_5)-\ln 4001}{1.3}\right) -0.12\tanh\!\left(\frac{X_1}{24}\right) -0.10\tanh\!\left(\frac{X_4-45}{12}\right).
\end{aligned}
$$

This design induces adverse selection (treated units can have lower observed outcomes than controls even when treatment helps on average).
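The intercept $\alpha_d$ is not listed on this page; one plausible way to hit a target treatment rate is to calibrate it numerically. The sketch below does that by bisection over a handful of hypothetical score values. Two things are assumptions: `bound(value, limit)` is taken to be a hard clip to `[-limit, limit]` (it could equally be a smooth squashing function), and the calibration routine itself is illustrative rather than the library's method.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bound(value, limit):
    # ASSUMPTION: hard clip to [-limit, limit]; the page does not define bound().
    return max(-limit, min(limit, value))


def mean_propensity(alpha, scores, limit=2.0):
    """Average P(D=1|X) over units, given each unit's score f_d(X)."""
    return sum(sigmoid(alpha + bound(s, limit)) for s in scores) / len(scores)


def calibrate_intercept(scores, target_rate=0.05, limit=2.0,
                        low=-20.0, high=20.0, tol=1e-10):
    """Bisection on alpha: the mean propensity is strictly increasing in alpha."""
    for _ in range(200):
        mid = (low + high) / 2.0
        if mean_propensity(mid, scores, limit) < target_rate:
            low = mid
        else:
            high = mid
        if high - low < tol:
            break
    return (low + high) / 2.0


# Hypothetical f_d(X) values for seven units (not taken from the dataset).
scores = [-3.1, -0.8, -0.2, 0.0, 0.4, 1.1, 2.7]
alpha_d = calibrate_intercept(scores)
print(alpha_d, mean_propensity(alpha_d, scores))
```

Bounding the score before adding the intercept keeps individual propensities away from 0 and 1, which preserves overlap even when the raw score is extreme.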

3. Heterogeneous Treatment Effect ($\tau(X)$)

The treatment effect on the link (log-mean) scale is:

$$
\begin{aligned}
\tau(X) &= 0.08 + 0.08 \tanh\left(\ln\frac{1+X_2}{6}\right) + 0.11 X_8 + 0.03 X_9 + 0.03 X_{11} - 0.07 \tanh\frac{X_1}{48} - 0.05 \tanh\frac{X_4-40}{15}\\
&\quad + 0.03 X_{10} \tanh\left(\ln\frac{1+X_3}{61}\right) - 0.04 \tanh\left(\ln\frac{1+X_7}{2.5}\right)
\end{aligned}
$$

and is clipped to $[0.005, 0.35]$.

Important: this clipping applies to $\tau(X)$ on the link scale, not to the CATE itself.
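The link-scale effect can be transcribed directly from the formula. The sketch below is such a transcription (the function and argument names are illustrative), evaluated on the confounder values of the first two rows of the sample output further down this page so the results can be compared with that table's `tau_link` column.

```python
import math


def tanh_log_ratio(value, reference):
    """tanh(ln((1 + value) / reference)), the saturating transform used in tau(X)."""
    return math.tanh(math.log((1.0 + value) / reference))


def tau_link(tenure, sessions, spend, age, tickets, premium, mobile, urban, referred):
    """Link-scale treatment effect tau(X), clipped to [0.005, 0.35]."""
    raw = (
        0.08
        + 0.08 * tanh_log_ratio(sessions, 6.0)
        + 0.11 * premium
        + 0.03 * mobile
        + 0.03 * referred
        - 0.07 * math.tanh(tenure / 48.0)
        - 0.05 * math.tanh((age - 40.0) / 15.0)
        + 0.03 * urban * tanh_log_ratio(spend, 61.0)
        - 0.04 * tanh_log_ratio(tickets, 2.5)
    )
    return min(0.35, max(0.005, raw))


# Confounder values copied from rows 0 and 1 of the sample output on this page.
row0 = tau_link(tenure=28.814654, sessions=1.0, spend=77.936767, age=50.234101,
                tickets=2.0, premium=1, mobile=1, urban=1, referred=0)
row1 = tau_link(tenure=25.913345, sessions=3.0, spend=53.777740, age=28.115859,
                tickets=0.0, premium=1, mobile=1, urban=0, referred=1)
print(row0, row1)  # compare with the tau_link column of the sample output
```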

4. Outcome Model ($Y$)

The outcome follows a Tweedie (two-part hurdle) model:

  1. Zero-Inflation (Participation): $Y > 0$ with probability

     $$p_{pos} = \sigma\left(\alpha_{zi} + \sum_j \beta_{zi,j} X_j + g_{zi}(X) + D \cdot \tau_{zi}(X)\right)$$

     where $\alpha_{zi} = -0.8$, and the nonlinear components are:

     $$g_{zi}(X) = 0.45 \tanh\left(\ln\frac{1+X_2}{5}\right) + 0.25 \tanh\left(\ln\frac{1+X_3}{51}\right) + 0.20 \ln(1+X_6) - 0.30 \tanh\left(\ln\frac{1+X_7}{2}\right) + 0.15 X_8 + 0.10 X_9$$

     $$\tau_{zi}(X) = 0.03 + 0.02 X_8 + 0.015 X_{11} - 0.02 \tanh\frac{X_1}{48}$$

  2. Positive Outcome (Magnitude): If $Y > 0$, then

     $$Y \sim \text{Gamma}(\text{shape} = 2.2,\ \text{scale} = \mu_{pos} / 2.2)$$

     where $\ln(\mu_{pos}) = \sum_j \beta_{y,j} X_j + g_y(X) + D \cdot \tau(X)$. The baseline nonlinear part is:

     $$
     \begin{aligned}
     g_y(X) &= 1.4 \tanh\frac{X_1}{24} + 0.6 \left(\ln\frac{1+X_2}{6}\right)^2 + 0.25 \ln\frac{1+X_3}{61} \ln\frac{1+X_2}{6} + 0.35 \tanh\frac{\ln(1+X_5) - \ln 4001}{1.5} - 0.45 \ln(1+X_7)\\
     &\quad + 0.20 \ln(1+X_6) \tanh\frac{X_1}{18} + 0.30 X_9 \ln\frac{1+X_2}{6} + 0.25 X_8 (X_{10}-0.5) + 0.15 X_{11} \tanh\left(\frac{X_1}{12}-1\right) - 0.20 \tanh\frac{X_4-40}{15}
     \end{aligned}
     $$

The oracle natural-scale CATE is always computed as

$$\mathrm{CATE}(X) = g_1(X) - g_0(X),$$

where $g_d(X) = \mathbb{E}[Y \mid X, D=d]$ for $d \in \{0, 1\}$ under the two-part model (these conditional means are distinct from the treatment-score term $g_d(X)$ of Section 2).
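Under a two-part model the conditional mean factorizes as $\mathbb{E}[Y \mid X, D] = p_{pos} \cdot \mu_{pos}$, which is what makes the oracle CATE computable in closed form. The sketch below illustrates this with the two linear predictors passed in as plain numbers, since the coefficient vectors $\beta_{zi}$ and $\beta_y$ are not listed on this page. All function names and the example predictor values are hypothetical.

```python
import math
import random

GAMMA_SHAPE = 2.2  # shape parameter stated in the text


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def expected_outcome(eta_zi, eta_pos):
    """E[Y] = P(Y > 0) * E[Y | Y > 0] for the two-part model."""
    return sigmoid(eta_zi) * math.exp(eta_pos)


def oracle_cate(eta_zi_0, tau_zi, eta_pos_0, tau_link):
    """g1(X) - g0(X), given control-arm predictors and both link-scale effects."""
    g0 = expected_outcome(eta_zi_0, eta_pos_0)
    g1 = expected_outcome(eta_zi_0 + tau_zi, eta_pos_0 + tau_link)
    return g1 - g0


def sample_outcome(eta_zi, eta_pos, rng):
    """Draw Y: zero with probability 1 - p_pos, otherwise Gamma with mean mu_pos."""
    if rng.random() >= sigmoid(eta_zi):
        return 0.0
    mu_pos = math.exp(eta_pos)
    return rng.gammavariate(GAMMA_SHAPE, mu_pos / GAMMA_SHAPE)


# Same link-scale effects, different (hypothetical) baselines on the log-mean scale.
low_baseline = oracle_cate(eta_zi_0=0.0, tau_zi=0.03, eta_pos_0=2.0, tau_link=0.35)
high_baseline = oracle_cate(eta_zi_0=0.0, tau_zi=0.03, eta_pos_0=5.0, tau_link=0.35)
print(low_baseline, high_baseline)
```

Because the effect enters multiplicatively through the log link, the same $\tau(X)$ yields a much larger natural-scale CATE for units with a higher baseline, which is why the clip on $\tau(X)$ does not bound the CATE.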

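Result

| | y | d | tenure_months | avg_sessions_week | spend_last_month | age_years | income_monthly | prior_purchases_12m | support_tickets_90d | premium_user | mobile_user | urban_resident | referred_user | m | m_obs | tau_link | g0 | g1 | cate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.000000 | 0.0 | 28.814654 | 1.0 | 77.936767 | 50.234101 | 1926.698301 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 0.0 | 0.045453 | 0.045453 | 0.089095 | 8.137981 | 9.142395 | 1.004414 |
| 1 | 80.099611 | 1.0 | 25.913345 | 3.0 | 53.777740 | 28.115859 | 5104.271509 | 3.0 | 0.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.041514 | 0.041514 | 0.246679 | 60.459257 | 78.817307 | 18.358049 |
| 2 | 6.400482 | 1.0 | 24.969929 | 10.0 | 134.764322 | 22.907062 | 5267.938255 | 8.0 | 3.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.052593 | 0.052593 | 0.162968 | 7.712855 | 9.138577 | 1.425723 |
| 3 | 2.788238 | 0.0 | 40.655089 | 5.0 | 59.517074 | 31.970490 | 6597.327018 | 3.0 | 2.0 | 1.0 | 1.0 | 1.0 | 0.0 | 0.036221 | 0.036221 | 0.188755 | 25.386510 | 31.159932 | 5.773422 |
| 4 | 0.000000 | 0.0 | 18.560899 | 3.0 | 74.370930 | 39.237248 | 4930.009628 | 5.0 | 1.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.036343 | 0.036343 | 0.174757 | 15.359250 | 18.600227 | 3.240977 |

Result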

Ground truth ATE is 19.409586529660793

Ground truth ATTE is 10.914991423363865
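The page does not show how these two numbers are computed. A natural reading, assumed here, is that the ATE is the mean of the oracle `cate` column over all units and the ATTE is the same mean restricted to treated units. The sketch applies those definitions to the five preview rows shown earlier, so its results are toy values and will not match the full-sample figures above.

```python
# (d, cate) pairs copied from the five preview rows of the generated data.
preview_rows = [
    (0, 1.004414),
    (1, 18.358049),
    (1, 1.425723),
    (0, 5.773422),
    (0, 3.240977),
]


def average_treatment_effect(rows):
    """Mean oracle CATE over all units (assumed definition of the ground-truth ATE)."""
    return sum(cate for _, cate in rows) / len(rows)


def average_effect_on_treated(rows):
    """Mean oracle CATE over treated units only (assumed definition of the ATTE)."""
    treated = [cate for d, cate in rows if d == 1]
    return sum(treated) / len(treated)


ate = average_treatment_effect(preview_rows)
atte = average_effect_on_treated(preview_rows)
print(ate, atte)
```

In the full dataset the ATTE is well below the ATE, which is consistent with the adverse-selection design: units with smaller effects are more likely to be treated.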

Result

CausalData(df=(100000, 13), treatment='d', outcome='y', confounders=['tenure_months', 'avg_sessions_week', 'spend_last_month', 'age_years', 'income_monthly', 'prior_purchases_12m', 'support_tickets_90d', 'premium_user', 'mobile_user', 'urban_resident', 'referred_user'])

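Result

| | treatment | count | mean | std | min | p10 | p25 | median | p75 | p90 | max |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 95051 | 76.087138 | 240.800713 | 0.0 | 0.0 | 0.0 | 8.39544 | 64.859278 | 190.227900 | 21396.007575 |
| 1 | 1.0 | 4949 | 58.506172 | 199.485625 | 0.0 | 0.0 | 0.0 | 0.00000 | 36.958280 | 148.837193 | 5143.642132 |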

Result

| | treatment | n | outlier_count | outlier_rate | lower_bound | upper_bound | has_outliers | method | tail |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 95051 | 11300 | 0.118884 | -97.288916 | 162.148194 | True | iqr | both |
| 1 | 1.0 | 4949 | 721 | 0.145686 | -55.437420 | 92.395699 | True | iqr | both |

Result

| | confounders | mean_d_0 | mean_d_1 | abs_diff | smd | ks_pvalue |
|---|---|---|---|---|---|---|
| 0 | premium_user | 0.751807 | 0.591837 | 0.159970 | -0.345721 | 0.00000 |
| 1 | income_monthly | 4549.385190 | 3918.058798 | 631.326392 | -0.277611 | 0.00000 |
| 2 | spend_last_month | 89.091801 | 67.375389 | 21.716412 | -0.268360 | 0.00000 |
| 3 | support_tickets_90d | 0.984545 | 1.259244 | 0.274699 | 0.253974 | 0.00000 |
| 4 | avg_sessions_week | 5.047753 | 4.230148 | 0.817606 | -0.201735 | 0.00000 |
| 5 | prior_purchases_12m | 3.904220 | 3.513639 | 0.390581 | -0.189372 | 0.00000 |
| 6 | tenure_months | 28.740100 | 25.559161 | 3.180939 | -0.184156 | 0.00000 |
| 7 | age_years | 36.435984 | 34.809083 | 1.626901 | -0.144142 | 0.00000 |
| 8 | referred_user | 0.271486 | 0.307133 | 0.035647 | 0.078671 | 0.00001 |
| 9 | urban_resident | 0.600793 | 0.568802 | 0.031991 | -0.064954 | 0.00013 |
| 10 | mobile_user | 0.874573 | 0.870075 | 0.004498 | -0.013477 | 0.99998 |
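The page does not state how the `smd` column is defined. A common choice, assumed here, is the difference in group means divided by the square root of the average of the two sample variances. The sketch below implements that definition and applies it to `premium_user`, using the group means from the balance table and the group sizes from the summary tables; for a binary variable the sample variance follows from the mean and the count. The function names are illustrative.

```python
import math
from statistics import mean, variance


def smd_from_moments(mean_0, var_0, mean_1, var_1):
    """Standardized mean difference with an equally weighted pooled SD (assumed definition)."""
    pooled_sd = math.sqrt((var_0 + var_1) / 2.0)
    return (mean_1 - mean_0) / pooled_sd


def smd(control, treated):
    """Same statistic computed from raw values, using sample variances (ddof = 1)."""
    return smd_from_moments(mean(control), variance(control),
                            mean(treated), variance(treated))


def binary_sample_variance(p, n):
    """Sample variance of a 0/1 variable with mean p over n observations."""
    return p * (1.0 - p) * n / (n - 1)


# premium_user: group means from the balance table, group sizes from the summary tables.
smd_premium = smd_from_moments(
    0.751807, binary_sample_variance(0.751807, 95051),
    0.591837, binary_sample_variance(0.591837, 4949),
)
print(smd_premium)  # compare with the smd value reported for premium_user
```

With raw columns available, `smd(control_values, treated_values)` gives the same statistic directly. Absolute values above roughly 0.1 are often read as meaningful imbalance, which applies to most confounders in the table.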