24.7 Case study: t-test
The following case study shows how to make t-tests faster using some of the techniques described above. It’s based on an example in Computing thousands of test statistics simultaneously in R by Holger Schwender and Tina Müller. I thoroughly recommend reading the paper in full to see the same idea applied to other tests.
Imagine we have run 1000 experiments (rows), each of which collects data on 50 individuals (columns). The first 25 individuals in each experiment are assigned to group 1 and the rest to group 2. We’ll first generate some random data to represent this problem:
m <- 1000
n <- 50
X <- matrix(rnorm(m * n, mean = 10, sd = 3), nrow = m)
grp <- rep(1:2, each = n / 2)
For data in this form, there are two ways to use t.test(). We can either use the formula interface or provide two vectors, one for each group. Timing reveals that the formula interface is considerably slower.
system.time(
  for (i in 1:m) {
    t.test(X[i, ] ~ grp)$statistic
  }
)
#>    user  system elapsed
#>   0.667   0.000   0.667

system.time(
  for (i in 1:m) {
    t.test(X[i, grp == 1], X[i, grp == 2])$statistic
  }
)
#>    user  system elapsed
#>   0.132   0.000   0.133
Of course, a for loop computes the values, but doesn't save them. We can use map_dbl() (Section 9.2.1) to do that. This adds a little overhead:
compT <- function(i) {
  t.test(X[i, grp == 1], X[i, grp == 2])$statistic
}

system.time(t1 <- purrr::map_dbl(1:m, compT))
#>    user  system elapsed
#>   0.136   0.000   0.136
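If you prefer to stay in base R, vapply() does the same job with similar overhead. A sketch, assuming compT is defined as above (the name t1_base is just for illustration):

# vapply() enforces that each result is a single double and should
# reproduce t1 exactly, since it applies the same compT() to each row.
t1_base <- vapply(1:m, compT, numeric(1))
stopifnot(all.equal(t1, t1_base))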
How can we make this faster? First, we could try doing less work. If you look at the source code of stats:::t.test.default(), you'll see that it does a lot more than just compute the t-statistic. It also computes the p-value and formats the output for printing. We can try to make our code faster by stripping out those pieces.
my_t <- function(x, grp) {
  t_stat <- function(x) {
    m <- mean(x)
    n <- length(x)
    var <- sum((x - m) ^ 2) / (n - 1)

    list(m = m, n = n, var = var)
  }

  g1 <- t_stat(x[grp == 1])
  g2 <- t_stat(x[grp == 2])

  se_total <- sqrt(g1$var / g1$n + g2$var / g2$n)
  (g1$m - g2$m) / se_total
}

system.time(t2 <- purrr::map_dbl(1:m, ~ my_t(X[., ], grp)))
#>    user  system elapsed
#>   0.025   0.000   0.025

stopifnot(all.equal(t1, t2))
This gives us about a six-fold speed improvement.
Now that we have a fairly simple function, we can make it faster still by vectorising it. Instead of looping over the array outside the function, we will modify t_stat() to work with a matrix of values. Thus, mean() becomes rowMeans(), length() becomes ncol(), and sum() becomes rowSums(). The rest of the code stays the same.
rowtstat <- function(X, grp) {
  t_stat <- function(X) {
    m <- rowMeans(X)
    n <- ncol(X)
    var <- rowSums((X - m) ^ 2) / (n - 1)

    list(m = m, n = n, var = var)
  }

  g1 <- t_stat(X[, grp == 1])
  g2 <- t_stat(X[, grp == 2])

  se_total <- sqrt(g1$var / g1$n + g2$var / g2$n)
  (g1$m - g2$m) / se_total
}

system.time(t3 <- rowtstat(X, grp))
#>    user  system elapsed
#>   0.011   0.000   0.011

stopifnot(all.equal(t1, t3))
That's much faster! The vectorised version is more than twice as fast as our previous effort, and around 60 times faster than where we started.
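If you want to see all three approaches side by side, a single bench::mark() call collects the timings and confirms that the results still agree. A sketch, assuming the bench package is installed and compT, my_t, and rowtstat are defined as above:

# Consolidated comparison; bench::mark()'s default check verifies that
# the three result vectors are equal, as the stopifnot() calls did.
bench::mark(
  map_t.test = purrr::map_dbl(1:m, compT),
  map_my_t   = purrr::map_dbl(1:m, ~ my_t(X[., ], grp)),
  rowtstat   = rowtstat(X, grp)
)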