
How to perform a simple grid search with Apache Spark


I am trying to use Scikit Learn's GridSearch class to tune the hyperparameters of a logistic regression algorithm.

However, even when running multiple jobs in parallel, GridSearch takes days to finish unless you are only tuning a single parameter. I thought about using Apache Spark to speed this up, but I have two questions.

  • To use Apache Spark, do you really need multiple machines to distribute the workload? For example, if you only have one laptop, is there no point in using Apache Spark?

  • Is there a simple way to use Scikit Learn's GridSearch within Apache Spark?

I have read the documentation, but it talks about running parallel workers over the entire machine learning pipeline, whereas I only want to use it for parameter tuning.

Imports

import datetime
%matplotlib inline

import pylab
import pandas as pd
import math
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.pylab as pylab

import numpy as np
import statsmodels.api as sm
from statsmodels.formula.api import ols

from sklearn import datasets, tree, metrics, model_selection
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import KNeighborsClassifier 
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.linear_model import LogisticRegression, LinearRegression, Perceptron
from sklearn.feature_selection import SelectKBest, chi2, VarianceThreshold, RFE
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
from sklearn.naive_bayes import GaussianNB

import findspark
findspark.init()
import pyspark
sc = pyspark.SparkContext()

from datetime import datetime as dt
import scipy
import itertools

ucb_w_reindex = pd.read_csv('clean_airbnb.csv')
ucb = pd.read_csv('clean_airbnb.csv')

pylab.rcParams['figure.figsize'] = (15, 10)
plt.style.use("fivethirtyeight")

new_style = {'grid': False}
plt.rc('axes', **new_style)

Algorithm Hyper Parameter Tuning

X = ucb.drop('country_destination', axis=1)
y = ucb['country_destination'].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = .3, random_state=42, stratify=y)

knn = KNeighborsClassifier()

parameters = {'leaf_size': range(1, 100), 'n_neighbors': range(1, 10), 'weights': ['uniform', 'distance'], 
              'algorithm': ['kd_tree', 'ball_tree', 'brute', 'auto']}


# ======== What I want to do in Apache Spark ========= #

%%time
parameters = {'n_neighbors': range(1, 100)}
clf1 = GridSearchCV(estimator=knn, param_grid=parameters, n_jobs=5).fit(X_train, y_train)
best = clf1.best_estimator_

# ==================================================== #
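The cell above can also be expressed with the plain PySpark API already initialised in the imports cell: broadcast the training data once, parallelize the candidate parameter settings as an RDD, and let each Spark task cross-validate one candidate. This is only a sketch, not a tuned implementation, and the helper names (evaluate, param_list) are illustrative:

from sklearn.base import clone
from sklearn.model_selection import cross_val_score

# Candidate settings, one dict per grid point.
param_list = [{'n_neighbors': k} for k in range(1, 100)]

# Ship the training data to every executor once.
X_bc = sc.broadcast(X_train.values)
y_bc = sc.broadcast(y_train)

def evaluate(params):
    # Each task fits and cross-validates a single candidate.
    model = clone(knn).set_params(**params)
    score = cross_val_score(model, X_bc.value, y_bc.value, cv=3).mean()
    return (score, params)

# Distribute the sweep and keep the best-scoring candidate.
best_score, best_params = (sc.parallelize(param_list, numSlices=len(param_list))
                             .map(evaluate)
                             .max(key=lambda t: t[0]))
best = clone(knn).set_params(**best_params).fit(X_train, y_train)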

1 Answer


You can use a library called spark-sklearn to run a distributed parameter sweep. You need either a cluster of machines or a single machine with multiple CPUs to get a parallel speedup.

Hope this helps,

Roope - Microsoft MMLSpark team
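The spark-sklearn library mentioned in the answer provides a drop-in replacement for Scikit Learn's GridSearchCV that takes the SparkContext as its first argument, so the grid-search cell stays almost unchanged. A minimal sketch, assuming spark-sklearn is installed and that its constructor signature matches the project documentation:

from spark_sklearn import GridSearchCV as SparkGridSearchCV

parameters = {'n_neighbors': list(range(1, 100))}
# Parameter combinations are evaluated as Spark tasks instead of local joblib workers.
clf1 = SparkGridSearchCV(sc, estimator=knn, param_grid=parameters).fit(X_train, y_train)
best = clf1.best_estimator_

On a single laptop you can still get some parallelism by starting the context with pyspark.SparkContext('local[*]'), which uses all local cores; a cluster is only needed to scale beyond one machine.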
