data.table row-wise sum, mean, min, max like dplyr?

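There are other posts about row-wise operators on data.table. They are either too simple or solve a specific scenario.

My question is more generic. There is a solution using dplyr. I have experimented but failed to find an equivalent solution using data.table syntax. Can you suggest an elegant data.table solution that reproduces the same results as the dplyr version?

EDIT 1: Benchmark summary of the suggested solutions on a real dataset (10 MB, 73,000 rows, statistics computed over 24 numeric columns). Benchmark results are subjective; however, the elapsed times were consistently reproducible.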

| Solution By | Speed compared to dplyr     |
|-------------|-----------------------------|
| Metrics v1  |  4.3 times SLOWER (use .SD) |
| Metrics v2  |  5.6 times FASTER           |
| ExperimenteR| 15   times FASTER           |
| Arun v1     |  3   times FASTER (Map func)|
| Arun v2     |  3   times FASTER (foo func)|
| Ista        |  4.5 times FASTER           |

EDIT 2: I added the NAcnt column the next day. That is why this column does not appear in the solutions suggested by the various contributors.

Data Setup

library(data.table)
dt <- data.table(ProductName = c("Lettuce", "Beetroot", "Spinach", "Kale", "Carrot"),
    Country = c("CA", "FR", "FR", "CA", "CA"),
    Q1 = c(NA, 61, 40, 54, NA), Q2 = c(22,  8, NA,  5, NA),
    Q3 = c(51, NA, NA, 16, NA), Q4 = c(79, 10, 49, NA, NA))

#    ProductName Country Q1 Q2 Q3 Q4
# 1:     Lettuce      CA NA 22 51 79
# 2:    Beetroot      FR 61  8 NA 10
# 3:     Spinach      FR 40 NA NA 49
# 4:        Kale      CA 54  5 16 NA
# 5:      Carrot      CA NA NA NA NA

SOLUTION using dplyr + rowwise()

library(dplyr) ; library(magrittr)
dt %>% rowwise() %>% 
    transmute(ProductName, Country, Q1, Q2, Q3, Q4,
     AVG = mean(c(Q1, Q2, Q3, Q4), na.rm=TRUE),
     MIN = min (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
     MAX = max (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
     SUM = sum (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
     NAcnt= sum(is.na(c(Q1, Q2, Q3, Q4))))

#   ProductName Country Q1 Q2 Q3 Q4      AVG MIN  MAX SUM NAcnt
# 1     Lettuce      CA NA 22 51 79 50.66667  22   79 152     1
# 2    Beetroot      FR 61  8 NA 10 26.33333   8   61  79     1
# 3     Spinach      FR 40 NA NA 49 44.50000  40   49  89     2
# 4        Kale      CA 54  5 16 NA 25.00000   5   54  75     1
# 5      Carrot      CA NA NA NA NA      NaN Inf -Inf   0     4

ERROR with data.table (compute entire column instead of per-row)

dt[, .(ProductName, Country, Q1, Q2, Q3, Q4,
    AVG = mean(c(Q1, Q2, Q3, Q4), na.rm=TRUE),
    MIN = min (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
    MAX = max (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
    SUM = sum (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
    NAcnt= sum(is.na(c(Q1, Q2, Q3, Q4))))]

#    ProductName Country Q1 Q2 Q3 Q4      AVG MIN MAX SUM NAcnt
# 1:     Lettuce      CA NA 22 51 79 35.90909   5  79 395     9
# 2:    Beetroot      FR 61  8 NA 10 35.90909   5  79 395     9
# 3:     Spinach      FR 40 NA NA 49 35.90909   5  79 395     9
# 4:        Kale      CA 54  5 16 NA 35.90909   5  79 395     9
# 5:      Carrot      CA NA NA NA NA 35.90909   5  79 395     9
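
The reason is that inside `j` each column name refers to the whole column, so `c(Q1, Q2, Q3, Q4)` concatenates all 20 values before the summary function runs. A minimal sketch that makes this visible, using the data from the setup above (the names `all_values` and `per_row` are only for illustration):

```r
library(data.table)

dt <- data.table(ProductName = c("Lettuce", "Beetroot", "Spinach", "Kale", "Carrot"),
    Country = c("CA", "FR", "FR", "CA", "CA"),
    Q1 = c(NA, 61, 40, 54, NA), Q2 = c(22,  8, NA,  5, NA),
    Q3 = c(51, NA, NA, 16, NA), Q4 = c(79, 10, 49, NA, NA))

# Without `by`, j is evaluated once and Q1..Q4 are full columns,
# so c() glues them into one vector of 4 * nrow(dt) values.
all_values <- dt[, c(Q1, Q2, Q3, Q4)]
length(all_values)                # 20 values, not 4
mean(all_values, na.rm = TRUE)    # the single AVG that gets recycled to every row

# Grouping by row number evaluates j once per row, so the same
# expression now sees only that row's four values.
per_row <- dt[, .(AVG = mean(c(Q1, Q2, Q3, Q4), na.rm = TRUE)), by = 1:nrow(dt)]
```

Grouping by row number gives the right numbers but calls `mean()` once per row, which is why the answers below look for vectorised alternatives.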

ALMOST solution but more complex and missing Q1,Q2,Q3,Q4 output columns

# use data.table's own melt() so the result is guaranteed to be a data.table
dtmelt <- melt(dt, id.vars=c("ProductName", "Country"),
            variable.name="Quarter", value.name="Qty")

dtmelt[, .(AVG = mean(Qty, na.rm=TRUE),
    MIN = min (Qty, na.rm=TRUE),
    MAX = max (Qty, na.rm=TRUE),
    SUM = sum (Qty, na.rm=TRUE),
    NAcnt= sum(is.na(Qty))), by = list(ProductName, Country)]

#    ProductName Country      AVG MIN  MAX SUM NAcnt
# 1:     Lettuce      CA 50.66667  22   79 152     1
# 2:    Beetroot      FR 26.33333   8   61  79     1
# 3:     Spinach      FR 44.50000  40   49  89     2
# 4:        Kale      CA 25.00000   5   54  75     1
# 5:      Carrot      CA      NaN Inf -Inf   0     4
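
One way to recover the missing Q1..Q4 columns is to join the aggregated statistics back onto the wide table. This is a sketch only: it assumes that (ProductName, Country) uniquely identifies a row, as it does in the sample data, and it computes just a subset of the statistics. The names `stats` and `wide_stats` are illustrative.

```r
library(data.table)

dt <- data.table(ProductName = c("Lettuce", "Beetroot", "Spinach", "Kale", "Carrot"),
    Country = c("CA", "FR", "FR", "CA", "CA"),
    Q1 = c(NA, 61, 40, 54, NA), Q2 = c(22,  8, NA,  5, NA),
    Q3 = c(51, NA, NA, 16, NA), Q4 = c(79, 10, 49, NA, NA))

# Long format: one row per product and quarter.
dtmelt <- melt(dt, id.vars = c("ProductName", "Country"),
               variable.name = "Quarter", value.name = "Qty")

# Per-product statistics, computed by group on the long table.
stats <- dtmelt[, .(AVG   = mean(Qty, na.rm = TRUE),
                    SUM   = sum(Qty, na.rm = TRUE),
                    NAcnt = sum(is.na(Qty))),
                by = .(ProductName, Country)]

# Join the statistics back onto the wide table to keep Q1..Q4.
wide_stats <- dt[stats, on = c("ProductName", "Country")]
```

The extra melt and join make this more work than a direct row-wise computation, which is what the question is asking for.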

4 Answers

  • 11

    You can use the efficient row-wise functions from the matrixStats package.

    library(matrixStats)
    dt[, `:=`(MIN = rowMins(as.matrix(.SD), na.rm=TRUE),
              MAX = rowMaxs(as.matrix(.SD), na.rm=TRUE),
              AVG = rowMeans(.SD, na.rm=TRUE),
              SUM = rowSums(.SD, na.rm=TRUE)), .SDcols=c("Q1", "Q2", "Q3", "Q4")]
    
    dt
    #    ProductName Country Q1 Q2 Q3 Q4 MIN  MAX      AVG SUM
    # 1:     Lettuce      CA NA 22 51 79  22   79 50.66667 152
    # 2:    Beetroot      FR 61  8 NA 10   8   61 26.33333  79
    # 3:     Spinach      FR 40 NA NA 49  40   49 44.50000  89
    # 4:        Kale      CA 54  5 16 NA   5   54 25.00000  75
    # 5:      Carrot      CA NA NA NA NA Inf -Inf      NaN   0
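
    The question later added an NAcnt column (EDIT 2), which this answer predates. A minimal sketch of one way to add it in the same vectorised style, using only base R's `rowSums()` and `is.na()`; the variable `qcols` is an illustrative name:

    ```r
    library(data.table)

    dt <- data.table(ProductName = c("Lettuce", "Beetroot", "Spinach", "Kale", "Carrot"),
        Country = c("CA", "FR", "FR", "CA", "CA"),
        Q1 = c(NA, 61, 40, 54, NA), Q2 = c(22,  8, NA,  5, NA),
        Q3 = c(51, NA, NA, 16, NA), Q4 = c(79, 10, 49, NA, NA))

    qcols <- c("Q1", "Q2", "Q3", "Q4")

    # is.na() on the matrix gives a logical matrix; rowSums() counts the TRUEs per row.
    dt[, NAcnt := rowSums(is.na(as.matrix(.SD))), .SDcols = qcols]
    ```

    Like the other matrixStats calls, this works on the whole matrix at once rather than looping over rows.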
    

    For a dataset with 500,000 rows (using data.table from CRAN):

    dt <- rbindlist(lapply(1:100000, function(i)dt))
    system.time(dt[, `:=`(MIN = rowMins(as.matrix(.SD), na.rm=T),
                          MAX = rowMaxs(as.matrix(.SD), na.rm=T),
                          AVG = rowMeans(.SD, na.rm=T),
                          SUM = rowSums(.SD, na.rm=T)), .SDcols=c("Q1", "Q2","Q3","Q4")])
    #  user  system elapsed 
    # 0.089   0.004   0.093
    

    rowwise (or by=1:nrow(dt)) is a "euphemism" for a for loop, as the following timings show:

    library(dplyr) ; library(magrittr)
    system.time(dt %>% rowwise() %>% 
      transmute(ProductName, Country, Q1, Q2, Q3, Q4,
                MIN = min (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
                MAX = max (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
                AVG = mean(c(Q1, Q2, Q3, Q4), na.rm=TRUE),
                SUM = sum (c(Q1, Q2, Q3, Q4), na.rm=TRUE)))
    #   user  system elapsed 
    # 80.832   0.111  80.974 
    
    system.time(dt[, `:=`(AVG = mean(as.numeric(.SD), na.rm=TRUE),
                          MIN = min(.SD, na.rm=TRUE),
                          MAX = max(.SD, na.rm=TRUE),
                          SUM = sum(.SD, na.rm=TRUE)),
                   .SDcols=c("Q1", "Q2","Q3","Q4"), by=1:nrow(dt)])
    #    user  system elapsed 
    # 141.492   0.196 141.757
    
  • 1

    Use by=1:nrow(dt) to perform row-wise operations in data.table:

    library(data.table)
    dt[, `:=`(AVG = mean(as.numeric(.SD), na.rm=TRUE),
              MIN = min(.SD, na.rm=TRUE),
              MAX = max(.SD, na.rm=TRUE),
              SUM = sum(.SD, na.rm=TRUE)),
       .SDcols=c("Q1", "Q2", "Q3", "Q4"), by=1:nrow(dt)]
       ProductName Country Q1 Q2 Q3 Q4      AVG MIN  MAX SUM
    1:     Lettuce      CA NA 22 51 79 50.66667  22   79 152
    2:    Beetroot      FR 61  8 NA 10 26.33333   8   61  79
    3:     Spinach      FR 40 NA NA 49 44.50000  40   49  89
    4:        Kale      CA 54  5 16 NA 25.00000   5   54  75
    5:      Carrot      CA NA NA NA NA      NaN Inf -Inf   0
    
    Warning messages:
    1: In min(c(NA_real_, NA_real_, NA_real_, NA_real_), na.rm = TRUE) :
      no non-missing arguments to min; returning Inf
    2: In max(c(NA_real_, NA_real_, NA_real_, NA_real_), na.rm = TRUE) :
      no non-missing arguments to max; returning -Inf
    

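    You get the warning messages because every value in row 5 is NA, so once na.rm=TRUE drops them, min and max are computed on an empty vector. For example, see below: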

    min(c(NA,NA,NA,NA),na.rm=TRUE)
    [1] Inf
    Warning message:
    In min(c(NA, NA, NA, NA), na.rm = TRUE) :
      no non-missing arguments to min; returning Inf
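
    If the warning and the Inf result are unwanted, one option is a small wrapper that returns NA for an all-NA row. This is a minimal sketch: `safe_min` is a hypothetical helper, not part of data.table, and it deliberately differs from the dplyr output in the question, which keeps Inf.

    ```r
    # Hypothetical helper: minimum of the non-missing values, or NA if there are none.
    safe_min <- function(x) {
        x <- x[!is.na(x)]                     # drop missing values first
        if (length(x) == 0L) NA_real_ else min(x)
    }

    safe_min(c(NA, NA, NA, NA))  # NA, and no warning is raised
    safe_min(c(NA, 22, 51, 79))  # 22
    ```

    The same pattern applies to max; sum and mean need no wrapper because they do not warn on empty input.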
    
  • 21

    The apply function can be used to perform row-wise computations. Defining the function separately keeps things clean:

    dstats <- function(x){
        c(mean(x,na.rm=TRUE),
          min(x, na.rm=TRUE),
          max(x, na.rm=TRUE),
          sum(x, na.rm=TRUE))
    }
    

    This function can now be applied to the rows of the data.table.

    # a character vector on the left of := needs no with = FALSE
    # (that combination is deprecated in data.table)
    (dt[,
       c("AVG", "MIN", "MAX", "SUM") := data.frame(t(apply(.SD, 1, dstats))),
       .SDcols=c("Q1", "Q2","Q3","Q4")])
    

    Note that the only advantage of doing this inside [.data.table is that it allows fast addition by reference using := .
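
    The t() in the call above is needed because apply(X, 1, f) stores each row's results in a column when f returns a vector. A small sketch on a plain matrix; the matrix `m` is illustrative and holds the first two rows of the question's data:

    ```r
    # First two rows of the question's Q1..Q4 values.
    m <- matrix(c(NA, 22, 51, 79,
                  61,  8, NA, 10), nrow = 2, byrow = TRUE)

    dstats <- function(x) {
        c(mean(x, na.rm = TRUE),
          min(x, na.rm = TRUE),
          max(x, na.rm = TRUE),
          sum(x, na.rm = TRUE))
    }

    res <- apply(m, 1, dstats)
    dim(res)     # 4 2: one column per input row
    dim(t(res))  # 2 4: one row per input row, ready to become four new columns
    ```
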

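    This is slower but more flexible than the matrixStats solution, and faster than the dplyr solution timed in @ExperimenteR's answer, clocking in at 36 seconds (my timings for the other methods are similar to those in @ExperimenteR's answer).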

  • 6

    Just another way (not efficient though, because na.omit() is called for every row and there are many memory allocations):

    require(data.table)
    new_cols = c("MIN", "MAX", "SUM", "AVG")
    dt[, (new_cols) := Map(function(x, f) f(x), 
                           list(na.omit(c(Q1,Q2,Q3,Q4))), 
                           list(min, max, sum, mean)),
       by = 1:nrow(dt)]
    
    #    ProductName Country Q1 Q2 Q3 Q4 MIN  MAX SUM      AVG
    # 1:     Lettuce      CA NA 22 51 79  22   79 152 50.66667
    # 2:    Beetroot      FR 61  8 NA 10   8   61  79 26.33333
    # 3:     Spinach      FR 40 NA NA 49  40   49  89 44.50000
    # 4:        Kale      CA 54  5 16 NA   5   54  75 25.00000
    # 5:      Carrot      CA NA NA NA NA Inf -Inf   0      NaN
    

    But as I mentioned, this will become much simpler once colwise() and rowwise() are implemented. The syntax in this case might look something like:

    dt[, rowwise(.SD, list(MIN=min, MAX=max, SUM=sum, AVG=mean), na.rm=TRUE), by = 1:nrow(dt)]
    # `by = ` is really not necessary in this case.
    

    Or even more directly for this case:

    rowwise(dt, list(...), na.rm=TRUE)
    

    Edit:

    Another variation:

    myNACount <- function(x, ...) length(attributes(x)$na.action)
    foo <- function(x, ...) {
        funs = c(min, max, mean, sum, myNACount)
        lapply(funs, function(f) f(x, ...))
    }
    
    new_cols = c("MIN", "MAX", "AVG", "SUM", "NAs")  # five names, in the same order as `funs` in foo()
    dt[, (new_cols) := foo(na.omit(c(Q1, Q2, Q3, Q4)), na.rm=TRUE), by=1:nrow(dt)]
    #    ProductName Country Q1 Q2 Q3 Q4 MIN  MAX      AVG SUM NAs
    # 1:     Lettuce      CA NA 22 51 79  22   79 50.66667 152   1
    # 2:    Beetroot      FR 61  8 NA 10   8   61 26.33333  79   1
    # 3:     Spinach      FR 40 NA NA 49  40   49 44.50000  89   2
    # 4:        Kale      CA 54  5 16 NA   5   54 25.00000  75   1
    # 5:      Carrot      CA NA NA NA NA Inf -Inf      NaN   0   4
    
