Hi all,
I've used Mata to code up a matching estimator for treatment effect estimation in a multinomial treatment setting. My goal is to reduce the runtime of a simulation of 1000 replications by cutting down the time it takes for the estimator to create these matches. Someone suggested trying the user-written -parallel- command; however, I have not been able to get it to work. The code below runs one replication of my simulation. I'm specifically trying to speed up the implementation of the match() function:
Thanks in advance!
Jessica
I've used Mata to code up a matching estimator for treatment effect estimation in a multinomial treatment setting. My goal is to reduce the runtime of a simulation of 1000 replications by cutting down the time it takes for the estimator to create these matches. Someone suggested trying the user-written -parallel- command; however, I have not been able to get it to work. The code below runs one replication of my simulation. I'm specifically trying to speed up the implementation of the match() function:
Code:
clear all
set seed 123456
set obs 10000
********************************************************************************
******************************* METHOD: VBKW ********************************
*Vector-Based Kernel weighting: This is a hybrid method that matches subjects
*based on similar propensity score vectors (all propensity scores). For example,
*treat = 1 subjects are matched to treat = 2 subjects on the logit of ps1
*within a bandwidth of .2*standard deviation of the logit of ps1, based on
*the Epanechnikov kernel function. Treat = 1 subjects are matched to treat = 2
*subjects only if subjects have similar values of logit(ps1), logit(ps2),
*and logit(ps3) within the specified bandwidth.
********************************************************************************
set matastrict on
mata: mata clear
mata:
// match(): kernel-weight every reference unit against the comparison units
// and write the results back into the Stata dataset through a view.
//
// Arguments
//   viewvars  space-separated variable names. The code addresses them by
//             position: 1 = matching score, 2 = support flag, 3 = matched
//             outcome, 4 = weight, 5 = observed outcome, 6 = usage counter.
//             With strategy "vbkw" two further score variables are appended
//             and read from columns 7 and 8, so exactly six names are
//             expected in that case.
//   strategy  "kw"   : match on column 1 only.
//             "vbkw" : match on column 1 AND require the two appended scores
//                      to lie within bwidth2 / bwidth3; the first name in
//                      viewvars must then be ps1, ps2 or ps3.
//   N_ref     number of reference units; they must occupy rows 1..N_ref of
//             the view (the caller sorts the data accordingly).
//   N_comp    number of comparison units. Not used in the computation: the
//             comparison rows are always rows N_ref+1..rows(view). Kept so
//             existing callers keep working.
//   bwidth, bwidth2, bwidth3
//             bandwidths for column 1 and for the two appended scores.
//
// Effects on the viewed variables
//   reference row i : col 6 = number of comparison units in band;
//                     col 3 = weighted comparison outcome, or missing when
//                             nothing is in band;
//                     col 2 and col 4 set to 0 when nothing is in band.
//   comparison rows : col 4 += normalised weights, col 6 += in-band flags,
//                     summed over all reference units.
//
// Rows with a missing value in any viewed variable are left out of the view
// (st_view() is called with selectvar = 0).
//
// Performance notes: the comparison-side outcome (col 5) is read once, and
// the comparison-side accumulators (cols 4 and 6) are kept in local vectors
// and stored back once after the loop, instead of re-subscripting the view
// on every iteration. Accumulation therefore happens in double precision;
// this gives the same stored values as long as those two variables are
// stored as double -- NOTE(review): confirm for any caller that passes
// float variables.
void match(string viewvars, string strategy, real scalar N_ref,
           real scalar N_comp, real scalar bwidth, real scalar bwidth2,
           real scalar bwidth3)
{
    real scalar      Nobs, i, use_vector, no_match
    string rowvector vars
    real rowvector   comp_rows
    real colvector   pscore_tref, pscore_tref2, pscore_tref3
    real colvector   pscore_tcomp, pscore_tcomp2, pscore_tcomp3
    real colvector   dif, in_band, weight, y_comp, w_comp, used_comp
    real matrix      X

    vars       = tokens(viewvars)
    use_vector = (strategy == "vbkw")

    if (strategy == "kw") {
        st_view(X = ., ., vars, 0)
    }
    else if (use_vector & vars[1] == "ps1") {
        st_view(X = ., ., (vars, "ps2", "ps3"), 0)
    }
    else if (use_vector & vars[1] == "ps2") {
        st_view(X = ., ., (vars, "ps1", "ps3"), 0)
    }
    else if (use_vector & vars[1] == "ps3") {
        st_view(X = ., ., (vars, "ps1", "ps2"), 0)
    }
    else {
        // Previously X stayed empty here and the function aborted later with
        // an uninformative subscript error.
        _error("match(): strategy must be kw, or vbkw with ps1, ps2 or ps3 as the first variable")
    }

    Nobs      = rows(X)
    comp_rows = (N_ref + 1)..Nobs

    pscore_tref  = X[(1..N_ref), 1]
    pscore_tcomp = X[comp_rows, 1]
    if (use_vector) {
        pscore_tref2  = X[(1..N_ref), 7]
        pscore_tref3  = X[(1..N_ref), 8]
        pscore_tcomp2 = X[comp_rows, 7]
        pscore_tcomp3 = X[comp_rows, 8]
    }

    // Loop-invariant input and running totals for the comparison rows.
    y_comp    = X[comp_rows, 5]
    w_comp    = X[comp_rows, 4]
    used_comp = X[comp_rows, 6]

    for (i = 1; i <= N_ref; i++) {
        dif = abs(pscore_tcomp :- pscore_tref[i])

        // Epanechnikov kernel on the primary score; values outside the band
        // are zeroed by the indicator below.
        weight  = (3/4) :* (1 :- (dif :/ bwidth) :^2)
        in_band = (dif :<= bwidth)
        if (use_vector) {
            in_band = in_band :*
                      (abs(pscore_tcomp2 :- pscore_tref2[i]) :<= bwidth2) :*
                      (abs(pscore_tcomp3 :- pscore_tref3[i]) :<= bwidth3)
        }
        weight = weight :* in_band

        used_comp = used_comp + in_band
        X[i, 6]   = colsum(in_band)

        // Evaluated once; the original tested mean(weight) four times.
        no_match = (mean(weight) == 0)

        if (no_match) {
            X[i, 2] = 0
            X[i, 4] = 0
            X[i, 3] = .
        }
        else {
            weight  = weight / sum(weight)
            // sum() treats missing as zero, so recoding zero products to
            // missing first (as the original did) does not change the total.
            X[i, 3] = sum(y_comp :* weight)
        }

        // When no_match holds, weight is all zeros and this adds nothing.
        w_comp = w_comp + weight
    }

    // Single write-back of the comparison-side accumulators.
    X[comp_rows, 4] = w_comp
    X[comp_rows, 6] = used_comp
}
end
********************************************************************************
/*Generate covariates:*/
matrix M1 = (1,.2\.2, 1)
drawnorm x1c x5c,corr(M1) double
drawnorm x3c x8c,corr(M1) double
matrix M2 = (1,.9\.9, 1)
drawnorm x2 x6c,corr(M2) double
drawnorm x4 x9c,corr(M2) double
gen x7 = rnormal()
gen x10 = rnormal()
* Dichotomise the latent draws at their sample means
sum x1c
gen double x1 = 0
replace x1 = 1 if x1c > r(mean)
sum x3c
gen double x3 = 0
replace x3 = 1 if x3c > r(mean)
sum x5c
gen double x5 = 0
replace x5 = 1 if x5c > r(mean)
sum x6c
gen double x6 = 0
replace x6 = 1 if x6c > r(mean)
sum x8c
gen double x8 = 0
replace x8 = 1 if x8c > r(mean)
sum x9c
gen double x9 = 0
replace x9 = 1 if x9c > r(mean)
* Multinomial-logit treatment probabilities
gen double xb2 = -4 + .2*(x1 + x2 + x3 + x4 + x5 + x6 + x7)
gen double xb3 = -1 + -.9*(x1 + x2 + x3 + x4 + x5 + x6 + x7)
gen double p1 = 1 / ( 1 + exp(xb2) + exp(xb3) )
gen double p2 = exp(xb2) / ( 1 + exp(xb2) + exp(xb3) )
gen double p3 = exp(xb3) / ( 1 + exp(xb2) + exp(xb3) )
*Assign treatment groups based on Setoguchi et al. (2008)
gen double tmtprob = runiform()
gen treat = 3
replace treat = 2 if tmtprob < p1 + p2
replace treat = 1 if tmtprob < p1
********************************************************************************
*Generate true outcomes
*Treatment effect distribution hom = homogeneous treatment effect
*Treatment effect distribution hetx10 = heterogeneous treatment effect, dependent on
*x10, which is a variable only associated with the outcome and not associated
*with treatment group assignment.
g double trueY1 = -.5 + .2*(x1 + x2 + x3 + x4 + x8 + x9 + x10)
g double trueY2 = -.4 + .2*(x1 + x2 + x3 + x4 + x8 + x9 + x10)
g double trueY3 = -.3 + .2*(x1 + x2 + x3 + x4 + x8 + x9 + x10)
********************************************************************************
*generate true Average Treatment Effects
g double trueATE12 = trueY1-trueY2 if treat !=.
g double trueATE13 = trueY1-trueY3 if treat !=.
generate double trueATE23 = trueY2 - trueY3 if treat != .
********************************************************************************
* True average treatment effects on the treated: one variable per pairwise
* contrast (12, 13, 23) and per treated group (1, 2, 3), created in the order
* trueATT12_1, trueATT12_2, ..., trueATT23_3.
foreach pair in 12 13 23 {
    local a = substr("`pair'", 1, 1)
    local b = substr("`pair'", 2, 1)
    forvalues grp = 1/3 {
        generate double trueATT`pair'_`grp' = trueY`a' - trueY`b' if treat == `grp'
    }
}
********************************************************************************
* Draw a random sample without replacement from the starting population,
* 400 units per treatment group; untagged units get treat set to missing.
forvalues t = 1/3 {
    randomtag if treat == `t', count(400) g(treat_`t')
    replace treat = . if treat_`t' == 0 & treat == `t'
}
********************************************************************************
* Units that were not drawn into the sample are dropped.
egen missing = rowmiss(treat)
drop if missing == 1
********************************************************************************
* Observed outcome: the potential outcome of the assigned treatment.
generate double trueY = cond(treat == 1, trueY1, ///
                        cond(treat == 2, trueY2, ///
                        cond(treat == 3, trueY3, .)))

/*MLOGIT*/
mlogit treat x1 x2 x3 x4 x5 x6 x8 x9 x10, baseoutcome(1)
predict double ps1 ps2 ps3 if e(sample), p
compress

* Common support -- also see Lopez and Gutman (2017), section 2.0.1
* Estimands and common support (12) and (13).
* For each propensity score, keep only values between the largest group
* minimum and the smallest group maximum.
generate byte _support = 1
foreach pvar in ps1 ps2 ps3 {
    forvalues t = 1/3 {
        summarize `pvar' if treat == `t'
        scalar min`t' = r(min)
        scalar max`t' = r(max)
    }
    scalar maxofmin = max(min1, min2, min3)
    scalar minofmax = min(max1, max2, max3)
    replace _support = 0 if `pvar' < maxofmin | `pvar' > minofmax
}

* Drop units outside the common support
keep if _support == 1
generate id = _n
compress
********************************************************************************
* Variables needed for the pairwise comparison of t1 and t2:
* treat1 is 1 for treat == 1, 0 for treat == 2, and missing otherwise.
capture drop treat1
generate treat1 = 1 if treat == 1
replace treat1 = 0 if treat == 2

capture drop *AT*support* *AT*weight* *AT*trueY*
foreach est in ATT ATU {
    generate byte _`est'support = .
    replace _`est'support = (treat1 <= 1)
    * matched outcome for this estimand
    generate double _`est'trueY = 0 if _`est'support
}
generate double _ATUweight = 1 - treat1 if _ATUsupport == 1
generate double _ATTweight = treat1 if _ATTsupport == 1
********************************************************************************
* VBKW ATT 1 v 2 | T = 2 estimation:
* sort so the reference group (treat == 2, i.e. treat1 == 0) comes first.
capture drop n_used
generate double n_used = 0 if treat1 != .
sort treat1 id

local varlist ps2 _ATUsupport _ATUtrueY _ATUweight trueY n_used
local method vbkw

count if treat == 2
local nref = r(N)
count if treat == 1
local ncomp = r(N)

* Bandwidths: 0.2 standard deviations of each propensity score
summarize ps2
local bandwidth = .2*r(sd)
summarize ps1
local bandwidth2 = .2*r(sd)
summarize ps3
local bandwidth3 = .2*r(sd)

mata: match("`varlist'", "`method'", `nref', `ncomp', `bandwidth', `bandwidth2', `bandwidth3')

replace _ATUsupport = 0 if _ATUweight == 0 | _ATUweight == .
* ATT 1 v 2 | T = 2:
* mean matched outcome minus mean observed outcome among supported treat == 2 units.
summarize _ATUtrueY if treat == 2 & _ATUsupport == 1
scalar m1u = r(mean)
local N2 = r(N)
summarize trueY if treat == 2 & _ATUsupport == 1
scalar m2u = r(mean)
scalar k_ATU = m1u - m2u
generate double vbkw_ATT12_2 = k_ATU
********************************************************************************
* VBKW ATT 1 v 2 | T = 1 estimation:
* descending sort so the reference group (treat == 1, i.e. treat1 == 1) comes first.
capture drop n_used
generate double n_used = 0 if treat1 != .
gsort -treat1 id

local varlist ps1 _ATTsupport _ATTtrueY _ATTweight trueY n_used
local method vbkw

count if treat == 1
local nref = r(N)
count if treat == 2
local ncomp = r(N)

* Bandwidths: 0.2 standard deviations of each propensity score
summarize ps1
local bandwidth = .2*r(sd)
summarize ps2
local bandwidth2 = .2*r(sd)
summarize ps3
local bandwidth3 = .2*r(sd)

mata: match("`varlist'", "`method'", `nref', `ncomp', `bandwidth', `bandwidth2', `bandwidth3')

replace _ATTsupport = 0 if _ATTweight == 0 | _ATTweight == .

* ATT 1 v 2 | T = 1:
* mean observed outcome minus mean matched outcome among supported treat == 1 units.
summarize trueY if treat == 1 & _ATTsupport == 1
scalar m1t = r(mean)
local N1 = r(N)
summarize _ATTtrueY if treat == 1 & _ATTsupport == 1
scalar m2t = r(mean)
scalar k_ATT = m1t - m2t
generate double vbkw_ATT12_1 = k_ATT

* ATE = (E[Y1 - Y2 | t = 1]*N1 + E[Y1 - Y2 | t = 2]*N2) / (N1 + N2)
generate double vbkw_ATE12 = (k_ATT*`N1'/(`N1'+`N2')) + ///
                             (k_ATU*`N2'/(`N1'+`N2'))

generate double _ATEweight = _ATTweight + _ATUweight if treat1 <= 1
Jessica
Comment