ArrayIndexOutOfBoundsException when fitting a Spark ML logistic regression in R

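I am trying to fit a logistic regression model using sparklyr::ml_logistic_regression. My training dataset has 42,457 rows and 785 columns; the response is a 0/1 integer in the label column, and all remaining columns are 0/1 integer features. My source data is in an R data frame (df), and I can successfully fit the model in base R using glm(label ~ ., data = df, family = binomial).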

Unfortunately, I cannot fit this model with ml_logistic_regression. The code is below; sc is an existing Spark connection.

library(sparklyr)
library(tidyverse)

copy_to(sc, df, "spark_train", overwrite = TRUE)
train_tbl <- tbl(sc, "spark_train")
fit <- ml_logistic_regression(train_tbl, label ~ .)

Here is the stack trace:

> fit <- ml_logistic_regression(train_tbl, label ~ .)
* No rows dropped by 'na.omit' call
Error: java.lang.ArrayIndexOutOfBoundsException: 1
    at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:343)
    at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:159)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:90)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:71)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
    at java.lang.reflect.Method.invoke(Unknown Source)
    at sparklyr.Invoke$.invoke(invoke.scala:94)
    at sparklyr.StreamHandler$.handleMethodCall(stream.scala:89)
    at sparklyr.StreamHandler$.read(stream.scala:55)
    at sparklyr.BackendHandler.channelRead0(handler.scala:49)
    at sparklyr.BackendHandler.channelRead0(handler.scala:14)
    at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
    at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:103)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
    at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:244)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
    at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:846)
    at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:131)
    at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511)
    at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468)
    at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382)
    at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354)
    at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:111)
    at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:137)
    at java.lang.Thread.run(Unknown Source)

Here is my sessionInfo():

R version 3.3.2 (2016-10-31)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] dplyr_0.7.1      purrr_0.2.2.2    readr_1.0.0      tidyr_0.6.3     
 [5] tibble_1.3.3     ggplot2_2.2.1    tidyverse_1.1.1  sparklyr_0.5.6  
 [9] robomarker_0.1.0 devtools_1.12.0 

loaded via a namespace (and not attached):
 [1] h2o_3.10.5.2     reshape2_1.4.2   haven_1.0.0      lattice_0.20-34 
 [5] colorspace_1.3-2 htmltools_0.3.5  yaml_2.1.14      base64enc_0.1-3 
 [9] rlang_0.1.1      foreign_0.8-67   glue_1.1.1       withr_1.0.2     
[13] DBI_0.7          rappdirs_0.3.1   dbplyr_1.0.0     modelr_0.1.0    
[17] readxl_1.0.0     bindrcpp_0.2     bindr_0.1        plyr_1.8.4      
[21] stringr_1.2.0    munsell_0.4.3    commonmark_1.1   gtable_0.2.0    
[25] cellranger_1.1.0 rvest_0.3.2      psych_1.7.3.21   memoise_1.0.0   
[29] forcats_0.2.0    httpuv_1.3.3     parallel_3.3.2   broom_0.4.2     
[33] Rcpp_0.12.10     xtable_1.8-2     backports_1.0.5  scales_0.4.1    
[37] jsonlite_1.2     config_0.2       mime_0.5         mnormt_1.5-5    
[41] hms_0.3          digest_0.6.12    stringi_1.1.2    shiny_1.0.3     
[45] grid_3.3.2       rprojroot_1.2    bitops_1.0-6     tools_3.3.2     
[49] magrittr_1.5     RCurl_1.95-4.8   lazyeval_0.2.0   pkgconfig_2.0.1 
[53] xml2_1.1.1       lubridate_1.6.0  assertthat_0.1   roxygen2_6.0.1  
[57] httr_1.2.1       rstudioapi_0.6   R6_2.2.0         rsparkling_0.2.0
[61] nlme_3.1-128

Any idea why this happens?

1 Answer

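    This error may be caused by the training dataset containing only one label class. Check that more than one label value is present; depending on your Spark version, you may be limited to exactly two labels (i.e. 0 and 1, for binomial regression).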
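One way to run that check is to count the rows for each label value. The sketch below is only an illustration: the small df is a made-up stand-in for the question's data frame, and it assumes the response column is named label, as in the question. The same dplyr pipeline should also work on train_tbl, since dplyr translates it to Spark SQL, but that has not been verified against the versions listed above.

```r
library(dplyr)

# Toy stand-in for the question's `df` (an assumption for illustration only);
# replace it with the real data frame or with train_tbl.
df <- data.frame(
  label = c(0L, 1L, 1L, 0L, 1L),
  x1    = c(1L, 0L, 1L, 1L, 0L)
)

# Count rows per distinct label value.
# collect() is a no-op for a local data frame and pulls the (small)
# result into R when the input is a Spark table.
label_counts <- df %>%
  group_by(label) %>%
  summarise(n = n()) %>%
  collect()

print(label_counts)

# Binomial logistic regression expects exactly two classes.
if (nrow(label_counts) != 2) {
  warning("Expected 2 label classes, found ", nrow(label_counts))
}
```

If the counts computed on train_tbl differ from those computed on the local df, the data may have changed during copy_to, which would be worth investigating before refitting.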
