Hi all,
I've used Mata to code up a matching estimator for treatment effect estimation in a multinomial treatment setting. My goal is to reduce the runtime of a simulation of 1000 replications by cutting down the time it takes for the estimator to create the matches. Someone suggested trying the user-written -parallel- command; however, I've not been able to get it to work. The code below shows what one replication of my simulation produces; I'm specifically trying to speed up the match() function.
Thanks in advance!
Code:
clear all
set seed 123456
set obs 10000
********************************************************************************
******************************* METHOD: VBKW ********************************
*Vector-Based Kernel weighting: This is a hybrid method that matches subjects
*based on similar propensity score vectors (all propensity scores). For example,
*treat = 1 subjects are matched to treat = 2 subjects on the logit of ps1
*within a bandwidth of .2*standard deviation of the logit of ps1, based on
*the Epanechnikov kernel function. Treat = 1 subjects are matched to treat = 2
*subjects only if subjects have similar values of logit(ps1), logit(ps2),
*and logit(ps3) within the specified bandwidth.
********************************************************************************
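*For reference, the weight that match() below computes for reference unit i and
*comparison unit j (a restatement of the code, not a formula taken from a paper):
*    w_ij = (3/4)*(1 - (d_ij/h)^2) if d_ij <= h, d2_ij <= h2 and d3_ij <= h3, else 0
*where d, d2, d3 are the absolute differences in the three scores passed to
*match() and h, h2, h3 are bwidth, bwidth2, bwidth3. The weights are then
*rescaled to sum to 1 over j.
*Optional check of the kernel part at d = 0, h/2 and h, assuming a hypothetical
*bandwidth h = 1 (this line only displays a vector and does not touch the data):
mata: (3/4) :* (1 :- ((0 \ .5 \ 1) :/ 1) :^2)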
set matastrict on
mata: mata clear
mata:
void match(string viewvars, string strategy, real scalar N_ref, real scalar N_comp, real scalar bwidth, real scalar bwidth2, real scalar bwidth3)
{
    real scalar Nobs, pscore_ref, pscore_ref2, pscore_ref3, i, j
    real colvector pscore_tref, pscore_tref2, pscore_tref3, pscore_tcomp, pscore_tcomp2, pscore_tcomp3, dif, dif2, dif3, weight, _y
    real matrix X

    Nobs = pscore_ref = pscore_ref2 = pscore_ref3 = i = j = .
    pscore_tref = pscore_tref2 = pscore_tref3 = pscore_tcomp = pscore_tcomp2 = pscore_tcomp3 = dif = dif2 = dif3 = weight = _y = .

    // View columns: 1 = score, 2 = support, 3 = matched outcome, 4 = weight,
    // 5 = outcome, 6 = n_used; for "vbkw" the other two scores become columns 7 and 8.
    if (strategy == "kw") {
        st_view(X = ., ., tokens(viewvars), 0)
    }
    else if (strategy == "vbkw" & tokens(viewvars)[1] == "ps1") {
        st_view(X = ., ., tokens(viewvars + " ps2" + " ps3"), 0)
    }
    else if (strategy == "vbkw" & tokens(viewvars)[1] == "ps2") {
        st_view(X = ., ., tokens(viewvars + " ps1" + " ps3"), 0)
    }
    else if (strategy == "vbkw" & tokens(viewvars)[1] == "ps3") {
        st_view(X = ., ., tokens(viewvars + " ps1" + " ps2"), 0)
    }

    // Rows 1..N_ref are the reference group; the remaining rows are the
    // comparison group, so the data must be sorted accordingly before the call.
    Nobs = rows(X)
    pscore_tref = X[(1..N_ref), 1]
    pscore_tcomp = X[((N_ref + 1)..Nobs), 1]
    if (strategy == "vbkw") {
        pscore_tref2 = X[(1..N_ref), 7]
        pscore_tref3 = X[(1..N_ref), 8]
        pscore_tcomp2 = X[((N_ref + 1)..Nobs), 7]
        pscore_tcomp3 = X[((N_ref + 1)..Nobs), 8]
    }

    for (i = 1; i <= N_ref; i = i + 1) {
        // Absolute score differences between reference unit i and every comparison unit
        pscore_ref = pscore_tref[i, 1]
        dif = abs(pscore_tcomp :- pscore_ref)
        if (strategy == "vbkw") {
            pscore_ref2 = pscore_tref2[i, 1]
            pscore_ref3 = pscore_tref3[i, 1]
            dif2 = abs(pscore_tcomp2 :- pscore_ref2)
            dif3 = abs(pscore_tcomp3 :- pscore_ref3)
        }

        // Epanechnikov kernel weight, zeroed outside the bandwidth(s)
        weight = J(N_comp, 1, .)
        weight = (3/4) :* (1 :- (dif :/ bwidth) :^2)
        if (strategy == "kw") {
            dif = dif :<= bwidth
            weight = weight :* dif
            X[((N_ref + 1)..Nobs), 6] = X[((N_ref + 1)..Nobs), 6] :+ dif
            X[i, 6] = colsum(dif)
        }
        else {
            dif = dif :<= bwidth
            dif2 = dif2 :<= bwidth2
            dif3 = dif3 :<= bwidth3
            dif = dif :* dif2 :* dif3
            weight = weight :* dif
            X[((N_ref + 1)..Nobs), 6] = X[((N_ref + 1)..Nobs), 6] :+ dif
            X[i, 6] = colsum(dif)
        }

        // No match: drop unit i from support; otherwise normalize weights to sum to 1
        if (mean(weight) == 0) X[i, 2] = 0
        if (mean(weight) == 0) X[i, 4] = 0
        if (mean(weight) != 0) weight = weight / sum(weight)
        X[((N_ref + 1)..Nobs), 4] = X[((N_ref + 1)..Nobs), 4] :+ weight

        // Matched outcome for unit i: weighted sum of comparison outcomes
        _y = X[((N_ref + 1)..Nobs), 5] :* weight
        _editvalue(_y, 0, .)
        if (mean(weight) == 0) X[i, 3] = .
        if (mean(weight) != 0) X[i, 3] = sum(_y)
    }
}
end
********************************************************************************
/*Generate covariates:*/
matrix M1 = (1,.2\.2, 1)
drawnorm x1c x5c,corr(M1) double
drawnorm x3c x8c,corr(M1) double
matrix M2 = (1,.9\.9, 1)
drawnorm x2 x6c,corr(M2) double
drawnorm x4 x9c,corr(M2) double
gen x7 = rnormal()
gen x10 = rnormal()
sum x1c
gen double x1 = 0
replace x1 = 1 if x1c > r(mean)
sum x3c
gen double x3 = 0
replace x3 = 1 if x3c > r(mean)
sum x5c
gen double x5 = 0
replace x5 = 1 if x5c > r(mean)
sum x6c
gen double x6 = 0
replace x6 = 1 if x6c > r(mean)
sum x8c
gen double x8 = 0
replace x8 = 1 if x8c > r(mean)
sum x9c
gen double x9 = 0
replace x9 = 1 if x9c > r(mean)
gen double xb2 = -4 + .2*(x1 + x2 + x3 + x4 + x5 + x6 + x7)
gen double xb3 = -1 + -.9*(x1 + x2 + x3 + x4 + x5 + x6 + x7)
gen double p1 = 1 / ( 1 + exp(xb2) + exp(xb3) )
gen double p2 = exp(xb2) / ( 1 + exp(xb2) + exp(xb3) )
gen double p3 = exp(xb3) / ( 1 + exp(xb2) + exp(xb3) )
*Assign treatment groups based on Setoguchi et al. (2008)
gen double tmtprob = runiform()
gen treat = 3
replace treat = 2 if tmtprob < p1 + p2
replace treat = 1 if tmtprob < p1
********************************************************************************
*Generate true outcomes
*Treatment effect distribution hom = homogeneous treatment effect
*Treatment effect distribution hetx10 = heterogeneous treatment effect, dependent on
*x10, which is a variable only associated with the outcome and not associated
*with treatment group assignment.
g double trueY1 = -.5 + .2*(x1 + x2 + x3 + x4 + x8 + x9 + x10)
g double trueY2 = -.4 + .2*(x1 + x2 + x3 + x4 + x8 + x9 + x10)
g double trueY3 = -.3 + .2*(x1 + x2 + x3 + x4 + x8 + x9 + x10)
********************************************************************************
*generate true Average Treatment Effects
g double trueATE12 = trueY1-trueY2 if treat !=.
g double trueATE13 = trueY1-trueY3 if treat !=.
g double trueATE23 = trueY2-trueY3 if treat !=.
********************************************************************************
*generate true Average Treatment Effects on the treated
g double trueATT12_1 = trueY1-trueY2 if treat == 1
g double trueATT12_2 = trueY1-trueY2 if treat == 2
g double trueATT12_3 = trueY1-trueY2 if treat == 3
g double trueATT13_1 = trueY1-trueY3 if treat == 1
g double trueATT13_2 = trueY1-trueY3 if treat == 2
g double trueATT13_3 = trueY1-trueY3 if treat == 3
g double trueATT23_1 = trueY2-trueY3 if treat == 1
g double trueATT23_2 = trueY2-trueY3 if treat == 2
g double trueATT23_3 = trueY2-trueY3 if treat == 3
********************************************************************************
*draw random sample without replacement from starting population.
*sample distribution across treatment groups:
foreach t of numlist 1 2 3 {
    randomtag if treat == `t', count(400) g(treat_`t')
    replace treat = . if treat_`t' == 0 & treat == `t'
}
********************************************************************************
*observations in the population that were not drawn into the random sample are dropped:
egen missing = rowmiss(treat)
drop if missing == 1
********************************************************************************
*generate observed outcome variable:
g double trueY = cond(treat == 1, trueY1, cond(treat == 2, trueY2, cond(treat == 3, trueY3, .)))
/*MLOGIT*/
mlogit treat x1 x2 x3 x4 x5 x6 x8 x9 x10, baseoutcome(1)
predict double ps1 ps2 ps3 if e(sample), p
compress
*Also see Lopez and Gutman (2017), section 2.0.1 Estimands and common support (12) and (13)
g byte _support = 1
sum ps1 if treat == 1
scalar min1 = r(min)
scalar max1 = r(max)
sum ps1 if treat == 2
scalar min2 = r(min)
scalar max2 = r(max)
sum ps1 if treat == 3
scalar min3 = r(min)
scalar max3 = r(max)
scalar maxofmin = max(min1, min2, min3)
scalar minofmax = min(max1, max2, max3)
replace _support = 0 if ps1 < maxofmin | ps1 > minofmax
sum ps2 if treat == 1
scalar min1 = r(min)
scalar max1 = r(max)
sum ps2 if treat == 2
scalar min2 = r(min)
scalar max2 = r(max)
sum ps2 if treat == 3
scalar min3 = r(min)
scalar max3 = r(max)
scalar maxofmin = max(min1, min2, min3)
scalar minofmax = min(max1, max2, max3)
replace _support = 0 if ps2 < maxofmin | ps2 > minofmax
sum ps3 if treat == 1
scalar min1 = r(min)
scalar max1 = r(max)
sum ps3 if treat == 2
scalar min2 = r(min)
scalar max2 = r(max)
sum ps3 if treat == 3
scalar min3 = r(min)
scalar max3 = r(max)
scalar maxofmin = max(min1, min2, min3)
scalar minofmax = min(max1, max2, max3)
replace _support = 0 if ps3 < maxofmin | ps3 > minofmax
*Drop units outside the common support
keep if _support == 1
g id = _n
compress
********************************************************************************
*generate vars needed for pairwise comparisons of t1 and t2:
cap drop treat1
gen treat1= 1 if treat==1
replace treat1 = 0 if treat==2
capture drop *AT*support* *AT*weight* *AT*trueY*
gen byte _ATTsupport = .
replace _ATTsupport = (treat1 <=1)
gen double _ATTtrueY = 0 if _ATTsupport //ATT matched outcome
gen byte _ATUsupport = .
replace _ATUsupport = (treat1 <=1)
gen double _ATUtrueY = 0 if _ATUsupport //ATU matched outcome
gen double _ATUweight = 1 - treat1 if _ATUsupport == 1
gen double _ATTweight = treat1 if _ATTsupport == 1
********************************************************************************
*VBKW ATT 1 v 2 | T = 2 estimation:
cap drop n_used
g double n_used = 0 if treat1 !=.
sort treat1 id
local varlist ps2 _ATUsupport _ATUtrueY _ATUweight trueY n_used
local method vbkw
count if treat == 2
local nref = r(N)
count if treat == 1
local ncomp = r(N)
sum ps2
local bandwidth = .2*r(sd)
sum ps1
local bandwidth2 = .2*r(sd)
sum ps3
local bandwidth3 = .2*r(sd)
mata: match("`varlist'", "`method'", `nref', `ncomp', `bandwidth', `bandwidth2', `bandwidth3')
replace _ATUsupport = 0 if _ATUweight == 0 | _ATUweight == .
*ATT 1 v 2 | T = 2:
sum _ATUtrueY if treat == 2 & _ATUsupport == 1
scalar m1u = r(mean)
local N2 = r(N)
sum trueY if treat == 2 & _ATUsupport == 1
scalar m2u = r(mean)
scalar k_ATU = m1u - m2u
g double vbkw_ATT12_2 = k_ATU
********************************************************************************
*VBKW ATT 1 v 2 | T = 1 estimation:
cap drop n_used
g double n_used = 0 if treat1 !=.
gsort -treat1 id
local varlist ps1 _ATTsupport _ATTtrueY _ATTweight trueY n_used
local method vbkw
count if treat == 1
local nref = r(N)
count if treat == 2
local ncomp = r(N)
sum ps1
local bandwidth = .2*r(sd)
sum ps2
local bandwidth2 = .2*r(sd)
sum ps3
local bandwidth3 = .2*r(sd)
mata: match("`varlist'", "`method'", `nref', `ncomp', `bandwidth', `bandwidth2', `bandwidth3')
replace _ATTsupport = 0 if _ATTweight == 0 | _ATTweight == .
*ATT 1 v 2 | T = 1:
sum trueY if treat == 1 & _ATTsupport == 1
scalar m1t = r(mean)
local N1 = r(N)
sum _ATTtrueY if treat == 1 & _ATTsupport == 1
scalar m2t = r(mean)
scalar k_ATT = m1t - m2t
g double vbkw_ATT12_1 = k_ATT
g double vbkw_ATE12 = (k_ATT*`N1'/(`N1'+`N2')) + (k_ATU*`N2'/(`N1'+`N2')) // (E[Y1 - Y2 | t = 1]*N1 + E[Y1 - Y2 | t = 2]*N2) / (N1 + N2)
gen double _ATEweight = _ATTweight + _ATUweight if treat1 <= 1
Jessica
