逻辑回归算法原理及Spark MLlib调用实例(Scala/Java/python)

2017-12-30 11:43:49来源:oschina作者:hblt-j人点击

分享

逻辑回归


算法原理:


逻辑回归是一个流行的二分类问题预测方法。它是Generalized Linear models的一个特殊应用以预测结果概率。它是一个线性模型如下列方程所示,其中损失函数为逻辑损失:


对于二分类问题,算法产出一个二值逻辑回归模型。给定一个新数据,由x表示,则模型通过下列逻辑方程来预测:


其中 。默认情况下,如果,结果为正,否则为负。和线性SVMs不同,逻辑回归的原始输出有概率解释(x为正的概率)。


二分类逻辑回归可以扩展为多分类逻辑回归来训练和预测多类别分类问题。如一个分类问题有K种可能结果,我们可以选取其中一种结果作为“中心点“,其他K-1个结果分别视为中心点结果的对立点。在spark.mllib中,取第一个类别为中心点类别。


*目前spark.ml逻辑回归工具仅支持二分类问题,多分类回归将在未来完善。


*当使用无拦截的连续非零列训练LogisticRegressionModel时,Spark MLlib为连续非零列输出零系数。这种处理不同于libsvm与R glmnet相似。


参数:


elasticNetParam:


类型:双精度型。


含义:弹性网络混合参数,范围[0,1]。


featuresCol:


类型:字符串型。


含义:特征列名。


fitIntercept:


类型:布尔型。


含义:是否训练拦截对象。


labelCol:


类型:字符串型。


含义:标签列名。


maxIter:


类型:整数型。


含义:最多迭代次数(>=0)。


predictionCol:


类型:字符串型。


含义:预测结果列名。


probabilityCol:


类型:字符串型。


含义:用以预测类别条件概率的列名。


regParam:


类型:双精度型。


含义:正则化参数(>=0)。


standardization:


类型:布尔型。


含义:训练模型前是否需要对训练特征进行标准化处理。


threshold:


类型:双精度型。


含义:二分类预测的阀值,范围[0,1]。


thresholds:


类型:双精度数组型。


含义:多分类预测的阀值,以调整预测结果在各个类别的概率。


tol:


类型:双精度型。


含义:迭代算法的收敛性。


weightCol:


类型:字符串型。


含义:列权重。


示例:


下面的例子展示如何训练使用弹性网络正则化的逻辑回归模型。elasticNetParam对应于 ,regParam对应于 。


Scala:

[plain] view plain copy



importorg.apache.spark.ml.classification.LogisticRegression//Loadtrainingdata
valtraining=spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")vallr=newLogisticRegression()
.setMaxIter(10)
.setRegParam(0.3)
.setElasticNetParam(0.8)//Fitthemodel
vallrModel=lr.fit(training)//Printthecoefficientsandinterceptforlogisticregression
println(s"Coefficients:${lrModel.coefficients}Intercept:${lrModel.intercept}")

Java:


[java] view plain copy



importorg.apache.spark.ml.classification.LogisticRegression;
importorg.apache.spark.ml.classification.LogisticRegressionModel;
importorg.apache.spark.sql.Dataset;
importorg.apache.spark.sql.Row;
importorg.apache.spark.sql.SparkSession;//Loadtrainingdata
Datasettraining=spark.read().format("libsvm")
.load("data/mllib/sample_libsvm_data.txt");LogisticRegressionlr=newLogisticRegression()
.setMaxIter(10)
.setRegParam(0.3)
.setElasticNetParam(0.8);//Fitthemodel
LogisticRegressionModellrModel=lr.fit(training);//Printthecoefficientsandinterceptforlogisticregression
System.out.println("Coefficients:"
+lrModel.coefficients()+"Intercept:"+lrModel.intercept());

Python:


[python] view plain copy



frompyspark.ml.classificationimportLogisticRegression#Loadtrainingdata
training=spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")lr=LogisticRegression(maxIter=10,regParam=0.3,elasticNetParam=0.8)#Fitthemodel
lrModel=lr.fit(training)#Printthecoefficientsandinterceptforlogisticregression
print("Coefficients:"+str(lrModel.coefficients))
print("Intercept:"+str(lrModel.intercept))

spark.ml逻辑回归工具同样支持提取模总结。LogisticRegressionTrainingSummary提供LogisticRegressionModel的总结。目前仅支持二分类问题,所以总结必须明确投掷到BinaryLogisticRegressionTrainingSummary。支持多分类问题后可能有所改善。


继续上面的例子:


Scala:

[plain] view plain copy



importorg.apache.spark.ml.Pipeline
importorg.apache.spark.ml.classification.DecisionTreeClassificationModel
importorg.apache.spark.ml.classification.DecisionTreeClassifier
importorg.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
importorg.apache.spark.ml.feature.{IndexToString,StringIndexer,VectorIndexer}//LoadthedatastoredinLIBSVMformatasaDataFrame.
valdata=spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")//Indexlabels,addingmetadatatothelabelcolumn.
//Fitonwholedatasettoincludealllabelsinindex.
vallabelIndexer=newStringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(data)
//Automaticallyidentifycategoricalfeatures,andindexthem.
valfeatureIndexer=newVectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)//featureswith>4distinctvaluesaretreatedascontinuous.
.fit(data)//Splitthedataintotrainingandtestsets(30%heldoutfortesting).
valArray(trainingData,testData)=data.randomSplit(Array(0.7,0.3))//TrainaDecisionTreemodel.
valdt=newDecisionTreeClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")//Convertindexedlabelsbacktooriginallabels.
vallabelConverter=newIndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels)//ChainindexersandtreeinaPipeline.
valpipeline=newPipeline()
.setStages(Array(labelIndexer,featureIndexer,dt,labelConverter))//Trainmodel.Thisalsorunstheindexers.
valmodel=pipeline.fit(trainingData)//Makepredictions.
valpredictions=model.transform(testData)//Selectexamplerowstodisplay.
predictions.select("predictedLabel","label","features").show(5)//Select(prediction,truelabel)andcomputetesterror.
valevaluator=newMulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy")
valaccuracy=evaluator.evaluate(predictions)
println("TestError="+(1.0-accuracy))valtreeModel=model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println("Learnedclassificationtreemodel:/n"+treeModel.toDebugString)

Java:


[python] view plain copy



importorg.apache.spark.ml.Pipeline;
importorg.apache.spark.ml.PipelineModel;
importorg.apache.spark.ml.PipelineStage;
importorg.apache.spark.ml.classification.DecisionTreeClassifier;
importorg.apache.spark.ml.classification.DecisionTreeClassificationModel;
importorg.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
importorg.apache.spark.ml.feature.*;
importorg.apache.spark.sql.Dataset;
importorg.apache.spark.sql.Row;
importorg.apache.spark.sql.SparkSession;//LoadthedatastoredinLIBSVMformatasaDataFrame.
Datasetdata=spark
.read()
.format("libsvm")
.load("data/mllib/sample_libsvm_data.txt");//Indexlabels,addingmetadatatothelabelcolumn.
//Fitonwholedatasettoincludealllabelsinindex.
StringIndexerModellabelIndexer=newStringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(data);//Automaticallyidentifycategoricalfeatures,andindexthem.
VectorIndexerModelfeatureIndexer=newVectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)//featureswith>4distinctvaluesaretreatedascontinuous.
.fit(data);//Splitthedataintotrainingandtestsets(30%heldoutfortesting).
Dataset[]splits=data.randomSplit(newdouble[]{0.7,0.3});
DatasettrainingData=splits[0];
DatasettestData=splits[1];//TrainaDecisionTreemodel.
DecisionTreeClassifierdt=newDecisionTreeClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures");//Convertindexedlabelsbacktooriginallabels.
IndexToStringlabelConverter=newIndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels());//ChainindexersandtreeinaPipeline.
Pipelinepipeline=newPipeline()
.setStages(newPipelineStage[]{labelIndexer,featureIndexer,dt,labelConverter});//Trainmodel.Thisalsorunstheindexers.
PipelineModelmodel=pipeline.fit(trainingData);//Makepredictions.
Datasetpredictions=model.transform(testData);//Selectexamplerowstodisplay.
predictions.select("predictedLabel","label","features").show(5);//Select(prediction,truelabel)andcomputetesterror.
MulticlassClassificationEvaluatorevaluator=newMulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy");
doubleaccuracy=evaluator.evaluate(predictions);
System.out.println("TestError="+(1.0-accuracy));DecisionTreeClassificationModeltreeModel=
(DecisionTreeClassificationModel)(model.stages()[2]);
System.out.println("Learnedclassificationtreemodel:/n"+treeModel.toDebugString());

微信扫一扫

第七城市微信公众平台