学习R语言:性能提升——速度和内存

目录

本文内容来自《R 语言编程艺术》(The Art of R Programming),有部分修改

时间和空间的权衡

编写快速 R 代码

  • 向量化,字节码编译,其他方法
  • 核心部分用编译型语言编写,如 C/C++(后续文章介绍)
  • 并行(后续文章介绍)

可怕的 for 循环

用向量化提升速度

示例1

x <- runif(10000000)
y <- runif(10000000)
z <- vector(length=10000000)

向量化版本

system.time(z <- x + y)
   user  system elapsed 
   0.01    0.01    0.03 

for 循环

注意:R 中每个操作都是函数

system.time(for (i in 1:length(x)) z[i] <- x[i] + y[i])
   user  system elapsed 
   0.84    0.00    0.85 

示例2

向量过滤?

oddcount <- function(x) return(sum(x%%2 == 1))
x <- sample(1:10000000, 10000000, replace=TRUE)

向量化版本

system.time(oddcount(x))
   user  system elapsed 
   0.22    0.00    0.22 

for 版本

system.time(
  {
    c <- 0
    for(i in 1:length(x)) {
      if (x[1]%%2 == 1) c <- c + 1
    }
  }
)
   user  system elapsed 
   2.13    0.00    2.13 

向量化函数举例

  • ifelse()
  • which()
  • where()
  • any()
  • all()
  • cumsum()
  • cumprod()
  • rowSums()
  • colSums()
  • combn()
  • outer()
  • lower.tri()
  • upper.tri()
  • expand.grid()

扩展案例:在蒙特卡罗模拟中获得更快的速度

示例1

for 循环

m_version_1 <- function() {
  sum <- 0
  nreps <- 100000
  for (i in 1:nreps) {
    xy <- rnorm(2)
    sum <- sum + max(xy)
  }
  print(sum/nreps)
}
system.time(m_version_1())
[1] 0.5652596
   user  system elapsed 
   0.25    0.00    0.27 

向量化,空间换时间

m_version_2 <- function() {
  nreps <- 100000
  xymat <- matrix(rnorm(2*nreps), ncol=2)
  maxs <- pmax(xymat[,1], xymat[,2])
  print(mean(maxs))
} 
system.time(m_version_2())
[1] 0.5616987
   user  system elapsed 
   0.02    0.00    0.03

示例2

缸1:10 蓝球,8 黄球

缸2:6 篮球,6 黄球

从缸 1 中随机取 1 个球放到缸 2 中,再从缸 2 中取 1 个球,求第二球为蓝色的概率

for 循环版本

sim1 <- function(nreps) {
  nb1 <- 10
  n1 <- 18
  n2 <- 13
  count <- 0
  for (i in 1:nreps) {
    nb2 <- 6
    if (runif(1) < nb1/n1) nb2 <- nb2 + 1
    if (runif(1) < nb2/n2) count <- count + 1
  }
  return(count/nreps)
}
system.time(print(sim1(100000)))
[1] 0.50497
   user  system elapsed 
   0.32    0.03    0.34 

apply 版本

sim2 <- function(nreps) {
  nb1 <- 10
  nb2 <- 6
  n1 <- 18
  n2 <- 13
  u <- matrix(c(runif(2*nreps)), nrow=nreps, ncol=2)
  
  simfun <- function(rw) {
    if (rw[1] < nb1/n1) nb2 <- nb2 + 1
    return (rw[2] < nb2 / n2)
  }
  z <- apply(u, 1, simfun)
  return(mean(z))
}
[1] 0.49965
   user  system elapsed 
   0.19    0.00    0.20 
system.time(print(sim2(100000)))

可以看到,apply 版本速度提升不明显

向量化版本

sim3 <- function(nreps) {
  nb1 <- 10
  nb2 <- 6
  n1 <- 18
  n2 <- 13
  u <- matrix(c(runif(2*nreps)), nrow=nreps, ncol=2)
  
  cndtn <- u[,1] <= nb1/n1 & u[,2] <= (nb2+1)/n2 |
           u[,1] > nb1/n1 & u[,2] <= nb2/n2
  return(mean(cndtn))
}
system.time(print(sim3(100000)))
[1] 0.50457
   user  system elapsed 
   0.04    0.00    0.03 

向量化版本速度有显著的提升

扩展案例:生成幂次矩阵

cbind() 版本

powers1 <- function(x, dg) {
  pw <- matrix(x, nrow=length(x))
  prod <- x
  for (i in 2:dg) {
    prod <- prod * x
    pw <- cbind(pw, prod)
  }
  return(pw)
}

一次分配所有内存

powers2 <- function(x, dg) {
  pw <- matrix(nrow=length(x), ncol=dg)
  prod <- x
  pw[, 1] <- prod
  for (i in 2:dg) {
    prod <- prod * x
    pw[, i] <- prod
  }
  return(pw)
}

对比运行时间

x <- runif(10000000)
system.time(powers1(x, 8))
   user  system elapsed 
   0.65    0.74    1.40 
system.time(powers2(x, 8))
   user  system elapsed 
   0.43    0.30    0.73 

使用 outer() 函数

outer(X, Y, FUN)FUN 函数应用与 X 和 Y 中元素的所有可能配对上

powers3 <- function(x, dg) {
  return (outer(x, 1:dg, "^"))
}
system.time(powers3(x, 8))
   user  system elapsed 
   3.71    0.27    3.97 

比前两个版本效果更差

使用 cumprod() 函数

powers4 <- function(x, dg) {
  repx <- matrix(rep(x, dg), nrow=length(x))
  return(t(apply(repx, 1, cumprod)))
}
system.time(powers4(x, 8))
   user  system elapsed 
  21.69    1.07   23.38 

效果更糟。

注意:性能有时是不可预测的

函数式编程和内存问题

绝大部分 R 对象都是不可变的。对变量进行重新赋值需要考虑性能问题。

向量赋值问题

rm(z)
z <- 1:10
tracemem(z)
z[3] <- 8
untracemem(z)
[1] "<000001F081FDE578>"
tracemem[0x000001f081fde578 -> 0x000001f093cc7ab8]: 
tracemem[0x000001f093cc7ab8 -> 0x000001f0be9b5a30]: 
z <- "[<-"(z, 3, value=8)

改变时拷贝

z <- runif(10)
tracemem(z)
z[3] <- 8
tracemem(z)
[1] "<000001F0BE38AF10>"
tracemem[0x000001f0be38af10 -> 0x000001f0be28a1b0]: 
[1] "<000001F0BE28A1B0>"
z <- 1:10000000
system.time(z[3] <- 8)
   user  system elapsed 
   0.06    0.00    0.06 
system.time(z[33] <- 88)
   user  system elapsed 
      0       0       0 

扩展案例:避免内存拷贝

m <- 50000
n <- 1000

for 循环

z <- list()
for (i in 1:m) z[[i]] <- sample(1:10, n, replace=TRUE)
system.time(for (i in 1:m) z[[i]][3] <- 8)
   user  system elapsed 
   0.11    0.00    0.11 

向量化

z <- matrix(sample(1:10, m*n, replace=TRUE), nrow=m)
system.time(z[,3] <- 8)
   user  system elapsed 
   0.15    0.00    0.14 

效率与 for 循环相当?

lapply()

set3 <- function(lv) {
  lv[3] <- 8
  return(lv)
}
z <- list()
for (i in 1:m) z[[i]] <- sample(1:10, n, replace=TRUE)
system.time(lapply(z, set3))
   user  system elapsed 
   0.22    0.00    0.21 

效率不如向量化版本

利用 Rprof() 来寻找代码的瓶颈

利用 Rprof() 来进行监视

powers1() 主要瓶颈在 cbind

x <- runif(10000000)
Rprof()
invisible(powers1(x, 8))
Rprof(NULL)
summaryRprof()
$by.self
          self.time self.pct total.time total.pct
"cbind"        0.64    86.49       0.64     86.49
"powers1"      0.08    10.81       0.74    100.00
"matrix"       0.02     2.70       0.02      2.70

$by.total
          total.time total.pct self.time self.pct
"powers1"       0.74    100.00      0.08    10.81
"cbind"         0.64     86.49      0.64    86.49
"matrix"        0.02      2.70      0.02     2.70

$sample.interval
[1] 0.02

$sampling.time
[1] 0.74

powers2() 没有明显的瓶颈

Rprof()
invisible(powers2(x, 8))
Rprof(NULL)
summaryRprof()
$by.self
          self.time self.pct total.time total.pct
"powers2"      0.42     91.3       0.46     100.0
"matrix"       0.04      8.7       0.04       8.7

$by.total
          total.time total.pct self.time self.pct
"powers2"       0.46     100.0      0.42     91.3
"matrix"        0.04       8.7      0.04      8.7

$sample.interval
[1] 0.02

$sampling.time
[1] 0.46

powers3()

Rprof()
invisible(powers3(x, 8))
Rprof(NULL)
summaryRprof()
$by.self
        self.time self.pct total.time total.pct
"outer"      3.12    99.36       3.12     99.36
"c"          0.02     0.64       0.02      0.64

$by.total
          total.time total.pct self.time self.pct
"outer"         3.12     99.36      3.12    99.36
"powers3"       3.12     99.36      0.00     0.00
"c"             0.02      0.64      0.02     0.64
"hook"          0.02      0.64      0.00     0.00
"Rprof"         0.02      0.64      0.00     0.00

$sample.interval
[1] 0.02

$sampling.time
[1] 3.14

Rprof() 的工作原理

每隔 0.02 秒检查一次函数调用栈,将结果写入到一个文件中,默认是 Rprof.out

powers4() 的输出结果很难解读

Rprof()
invisible(powers4(x, 8))
Rprof(NULL)
summaryRprof()
$by.self
                self.time self.pct total.time total.pct
"apply"             12.52    75.42      16.04     96.63
"array"              1.06     6.39       1.06      6.39
"FUN"                0.84     5.06       0.86      5.18
"aperm.default"      0.76     4.58       0.76      4.58
"unlist"             0.62     3.73       0.62      3.73
"t.default"          0.28     1.69       0.28      1.69
"matrix"             0.26     1.57       0.26      1.57
"lengths"            0.18     1.08       0.18      1.08
"tryCatch"           0.02     0.12       0.04      0.24
"any"                0.02     0.12       0.02      0.12
"as.list"            0.02     0.12       0.02      0.12
"grepl"              0.02     0.12       0.02      0.12

$by.total
                       total.time total.pct self.time self.pct
"powers4"                   16.60    100.00      0.00     0.00
"t"                         16.32     98.31      0.00     0.00
"apply"                     16.04     96.63     12.52    75.42
"array"                      1.06      6.39      1.06     6.39
"FUN"                        0.86      5.18      0.84     5.06
"aperm.default"              0.76      4.58      0.76     4.58
"aperm"                      0.76      4.58      0.00     0.00
"unlist"                     0.62      3.73      0.62     3.73
"t.default"                  0.28      1.69      0.28     1.69
"matrix"                     0.26      1.57      0.26     1.57
"lengths"                    0.18      1.08      0.18     1.08
"tryCatch"                   0.04      0.24      0.02     0.12
"any"                        0.02      0.12      0.02     0.12
"as.list"                    0.02      0.12      0.02     0.12
"grepl"                      0.02      0.12      0.02     0.12
"<Anonymous>"                0.02      0.12      0.00     0.00
"base::try"                  0.02      0.12      0.00     0.00
"cmpfun"                     0.02      0.12      0.00     0.00
"compiler:::tryCmpfun"       0.02      0.12      0.00     0.00
"doTryCatch"                 0.02      0.12      0.00     0.00
"findLocalsList"             0.02      0.12      0.00     0.00
"findLocalsList1"            0.02      0.12      0.00     0.00
"funEnv"                     0.02      0.12      0.00     0.00
"lapply"                     0.02      0.12      0.00     0.00
"make.functionContext"       0.02      0.12      0.00     0.00
"tryCatchList"               0.02      0.12      0.00     0.00
"tryCatchOne"                0.02      0.12      0.00     0.00

$sample.interval
[1] 0.02

$sampling.time
[1] 16.6

Rprof.out 文件内容类似

sample.interval=20000
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"matrix" "powers4" 
"aperm.default" "aperm" "apply" "t" "powers4" 
"aperm.default" "aperm" "apply" "t" "powers4" 
"aperm.default" "aperm" "apply" "t" "powers4" 
"aperm.default" "aperm" "apply" "t" "powers4" 
"aperm.default" "aperm" "apply" "t" "powers4" 
"aperm.default" "aperm" "apply" "t" "powers4" 
"aperm.default" "aperm" "apply" "t" "powers4" 

字节码编译

x <- runif(10000000)
y <- runif(10000000)
z <- vector(length=10000000)
g <- function() for(i in 1:length(x)) z[i] <<- x[i] - y[i]
system.time(g())
   user  system elapsed 
   2.11    0.00    2.11 
library(compiler)
f <- function() for(i in 1:length(x)) z[i] <<- x[i] + y[i]
cf <- cmpfun(f)
system.time(cf())
   user  system elapsed 
   2.11    0.00    2.12 

内存无法装下数据怎么办

分块,例如使用 read.table()skip 参数

使用 R 软件包来进行内存管理,例如 RMySQL,biglm,ff,bigmemory 等包

参考

学习 R 语言系列文章

快速入门

向量

矩阵和数组

列表

数据框

因子和表

编程结构

数学运算与模拟

面向对象编程

输入与输出

字符串操作

基础绘图

本文代码请访问如下项目:

https://github.com/perillaroc/the-art-of-r-programming