# Logistic Regression: Algorithm Principles and Spark MLlib Examples (Scala/Java/Python)

2017-12-30 11:43:49 Source: oschina Author: hblt-j

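* Currently the spark.ml logistic regression tool supports only binary classification; multiclass support is planned for a future release (multinomial logistic regression was added to spark.ml in Spark 2.1).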

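* When a LogisticRegressionModel is fit without an intercept on a dataset that has a constant nonzero column, Spark MLlib outputs zero coefficients for the constant nonzero columns. This behavior is the same as R glmnet but differs from LIBSVM.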
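To make "binary classification" concrete, here is a minimal pure-Python sketch (not Spark code) of how a fitted binary logistic regression model turns its coefficients and intercept into a prediction. The margin w·x + b is passed through the sigmoid to give P(y = 1 | x). Class 1 is predicted when that probability exceeds `threshold`, which defaults to 0.5 as in the Spark parameter of the same name. The function names and the model values are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function 1 / (1 + e^(-z)); maps any real margin into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def predict_binary(coefficients, intercept, features, threshold=0.5):
    """Return (P(y = 1 | x), predicted label) for a binary logistic regression model."""
    # Margin: dot product of coefficients and features, plus the intercept.
    margin = sum(w * x for w, x in zip(coefficients, features)) + intercept
    p1 = sigmoid(margin)
    # Predict class 1 only when its probability exceeds the threshold.
    label = 1.0 if p1 > threshold else 0.0
    return p1, label


# Hypothetical coefficients, intercept and feature vector.
p, y = predict_binary([0.5, -1.0], 0.2, [2.0, 0.5])
print(f"P(y=1) = {p:.4f}, prediction = {y}")
```

Raising `threshold` makes the model more conservative about predicting class 1. Lowering it does the opposite.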

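Parameters:

elasticNetParam: Double, in [0, 1]. The ElasticNet mixing parameter α. α = 0 gives an L2 penalty, α = 1 gives an L1 penalty, and values in between mix the two. Default 0.0.

featuresCol: String. Name of the features column. Default "features".

fitIntercept: Boolean. Whether to fit an intercept term. Default true.

labelCol: String. Name of the label column. Default "label".

maxIter: Int, non-negative. Maximum number of iterations. Default 100.

predictionCol: String. Name of the prediction output column. Default "prediction".

probabilityCol: String. Name of the output column for the predicted class conditional probabilities. Default "probability".

regParam: Double, non-negative. The regularization parameter λ. Default 0.0.

standardization: Boolean. Whether to standardize the training features before fitting the model. Default true.

threshold: Double, in [0, 1]. Decision threshold for binary classification: class 1 is predicted when its probability exceeds the threshold. Default 0.5.

thresholds: Array[Double]. Per-class thresholds for multiclass classification. The predicted class is the one with the largest value of p/t, where p is the class probability and t is that class's threshold.

tol: Double, non-negative. Convergence tolerance of the iterative optimizer. Default 1E-6.

weightCol: String. Name of the instance-weight column. If it is not set or empty, all instance weights are treated as 1.0.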
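The examples below call `setRegParam(0.3)` and `setElasticNetParam(0.8)`. The sketch shows how the two parameters combine in the elastic net penalty described in the Spark MLlib documentation: λ(α‖w‖₁ + (1 − α)/2 ‖w‖₂²), with λ = regParam and α = elasticNetParam. It is a hypothetical pure-Python helper. It penalizes only the coefficients (not the intercept). It also ignores the feature standardization that Spark applies when `standardization` is true.

```python
def elastic_net_penalty(weights, reg_param, elastic_net_param):
    """reg_param * (a * ||w||_1 + (1 - a) / 2 * ||w||_2^2) with a = elastic_net_param.

    a = 0 gives a pure L2 (ridge) penalty, a = 1 a pure L1 (lasso) penalty.
    """
    if not 0.0 <= elastic_net_param <= 1.0:
        raise ValueError("elastic_net_param must lie in [0, 1]")
    if reg_param < 0.0:
        raise ValueError("reg_param must be non-negative")
    l1 = sum(abs(w) for w in weights)          # ||w||_1
    l2_squared = sum(w * w for w in weights)   # ||w||_2^2
    a = elastic_net_param
    return reg_param * (a * l1 + (1.0 - a) / 2.0 * l2_squared)


w = [1.0, -2.0]  # hypothetical coefficient vector
print(elastic_net_penalty(w, 0.3, 0.8))  # the mostly-L1 mix used in the examples
print(elastic_net_penalty(w, 0.3, 0.0))  # pure L2
print(elastic_net_penalty(w, 0.3, 1.0))  # pure L1
```

With α = 0.8 the L1 term dominates. That term tends to drive some coefficients exactly to zero.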

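Scala:

```scala
import org.apache.spark.ml.classification.LogisticRegression
// Assumes a SparkSession `spark` (as in spark-shell) and the sample data shipped with Spark
val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)
// Fit the model
val lrModel = lr.fit(training)
// Print the coefficients and intercept for logistic regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
```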

Java:

```java
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Assumes a SparkSession `spark` and the sample data shipped with Spark
Dataset<Row> training = spark.read().format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt");

LogisticRegression lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8);

// Fit the model
LogisticRegressionModel lrModel = lr.fit(training);

// Print the coefficients and intercept for logistic regression
System.out.println("Coefficients: "
  + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
```

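Python:

```python
from pyspark.ml.classification import LogisticRegression
# Assumes a SparkSession `spark` and the sample data shipped with Spark
training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
lrModel = lr.fit(training)  # Fit the model
# Print the coefficients and intercept for logistic regression
print("Coefficients: " + str(lrModel.coefficients))
print("Intercept: " + str(lrModel.intercept))
```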
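`lr.fit(training)` runs an iterative optimizer whose length is bounded by `maxIter` and `tol`. Spark uses quasi-Newton methods here: L-BFGS, or OWL-QN when the penalty has an L1 component. The sketch below substitutes plain batch gradient descent with an L2 penalty only, just to show the loop structure. It minimizes mean log loss plus the penalty, and stops after `max_iter` updates or once the objective improves by less than `tol`. All names, the step size `step` and this stopping rule are assumptions of the sketch, not Spark's implementation.

```python
import math


def fit_logistic_gd(X, y, reg_param=0.0, max_iter=100, tol=1e-6, step=0.5):
    """Binary logistic regression with an L2 penalty, fitted by batch gradient descent.

    Objective: mean log loss + reg_param / 2 * ||w||^2 (intercept not penalized).
    Returns (weights, intercept, objective history).
    """
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0

    def margin(xi, w, b):
        return sum(wj * xj for wj, xj in zip(w, xi)) + b

    def objective(w, b):
        total = 0.0
        for xi, yi in zip(X, y):
            z = margin(xi, w, b)
            # log(1 + e^z) - y*z, written in a numerically stable form
            total += max(z, 0.0) + math.log1p(math.exp(-abs(z))) - yi * z
        return total / n + reg_param / 2.0 * sum(wj * wj for wj in w)

    history = [objective(w, b)]
    for _ in range(max_iter):
        grad_w, grad_b = [0.0] * d, 0.0
        for xi, yi in zip(X, y):
            err = 1.0 / (1.0 + math.exp(-margin(xi, w, b))) - yi  # sigmoid(z) - y
            for j in range(d):
                grad_w[j] += err * xi[j] / n
            grad_b += err / n
        w = [wj - step * (gj + reg_param * wj) for wj, gj in zip(w, grad_w)]
        b -= step * grad_b
        history.append(objective(w, b))
        if history[-2] - history[-1] < tol:  # improvement below tol: stop early
            break
    return w, b, history


# Tiny hypothetical dataset: label 1 when the single feature is positive.
X = [[-2.0], [-1.0], [1.0], [2.0]]
y = [0.0, 0.0, 1.0, 1.0]
w, b, history = fit_logistic_gd(X, y, reg_param=0.1, max_iter=200)
print("weights:", w, "intercept:", b, "iterations:", len(history) - 1)
```

A small `maxIter` (10 in the examples above) caps the work even if the tolerance has not been reached yet.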

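The spark.ml logistic regression tool also supports extracting a summary of the model. LogisticRegressionTrainingSummary provides a summary of a LogisticRegressionModel. Only binary classification is currently supported, so the summary must be explicitly cast to BinaryLogisticRegressionTrainingSummary. This may change once multiclass classification is supported.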
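The binary summary exposes metrics such as `areaUnderROC`. As an illustration of what that number measures, the pure-Python sketch below computes the area under the ROC curve by its pairwise definition. This is the fraction of (positive, negative) example pairs in which the positive example receives the higher score, with ties counted as one half. The function name and the data are hypothetical. Spark derives the value from ROC curve points rather than by comparing pairs.

```python
def area_under_roc(scores, labels):
    """AUC as the fraction of (positive, negative) pairs ranked correctly (ties = 0.5)."""
    pos = [s for s, lab in zip(scores, labels) if lab == 1.0]
    neg = [s for s, lab in zip(scores, labels) if lab == 0.0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for q in neg:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical predicted probabilities P(y=1) and true labels.
scores = [0.9, 0.8, 0.35, 0.6, 0.2]
labels = [1.0, 1.0, 1.0, 0.0, 0.0]
print(area_under_roc(scores, labels))
```

An AUC of 1.0 means every positive example outscores every negative one. An AUC of 0.5 is no better than random ranking.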

Scala:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Train and evaluate a decision tree classifier in a Pipeline.
// Assumes a SparkSession `spark` and the sample LIBSVM data shipped with Spark.
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Index labels. Fit on the whole dataset to include all labels in the index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with > 4 distinct values are treated as continuous
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a decision tree model.
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and tree in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions and display a few example rows.
val predictions = model.transform(testData)
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute the test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println("Test Error = " + (1.0 - accuracy))

val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)
```
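In the pipeline above, `StringIndexer` turns label values into indices and `IndexToString` maps predicted indices back to the original labels. The pure-Python sketch below shows that round trip. It assumes the default ordering by descending label frequency, so index 0 goes to the most frequent label. Breaking frequency ties by label value is this sketch's own choice. StringIndexer casts numeric labels to strings before indexing, so the labels appear as strings here. The function names and data are hypothetical.

```python
from collections import Counter


def fit_string_indexer(labels):
    """Order distinct labels by descending frequency; a label's index is its position."""
    counts = Counter(labels)
    # Ties are broken by the label string itself (this sketch's choice) for determinism.
    return sorted(counts, key=lambda lab: (-counts[lab], lab))


def index_labels(labels, ordered):
    """StringIndexer-style transform: label -> index, stored as a float."""
    position = {lab: float(i) for i, lab in enumerate(ordered)}
    return [position[lab] for lab in labels]


def index_to_string(indices, ordered):
    """IndexToString-style inverse: index -> original label."""
    return [ordered[int(i)] for i in indices]


raw = ["1.0", "0.0", "1.0", "1.0", "0.0", "2.0"]  # hypothetical labels
ordered = fit_string_indexer(raw)
indexed = index_labels(raw, ordered)
print(ordered)   # → ['1.0', '0.0', '2.0']
print(indexed)   # → [0.0, 1.0, 0.0, 0.0, 1.0, 2.0]
print(index_to_string(indexed, ordered) == raw)
```

This is why the evaluator compares against `indexedLabel` while the displayed rows use `predictedLabel`. The first is in index space; the second is back in the original label space.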

Java:

```java
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Train and evaluate a decision tree classifier in a Pipeline.
// Assumes a SparkSession `spark` and the sample LIBSVM data shipped with Spark.
Dataset<Row> data = spark
  .read()
  .format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt");

// Index labels. Fit on the whole dataset to include all labels in the index.
StringIndexerModel labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data);

// Automatically identify categorical features, and index them.
VectorIndexerModel featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with > 4 distinct values are treated as continuous
  .fit(data);

// Split the data into training and test sets (30% held out for testing).
Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];

// Train a decision tree model.
DecisionTreeClassifier dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures");

// Convert indexed labels back to original labels.
IndexToString labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels());

// Chain indexers and tree in a Pipeline.
Pipeline pipeline = new Pipeline()
  .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});

// Train model. This also runs the indexers.
PipelineModel model = pipeline.fit(trainingData);

// Make predictions and display a few example rows.
Dataset<Row> predictions = model.transform(testData);
predictions.select("predictedLabel", "label", "features").show(5);

// Select (prediction, true label) and compute the test error.
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy");
double accuracy = evaluator.evaluate(predictions);
System.out.println("Test Error = " + (1.0 - accuracy));

DecisionTreeClassificationModel treeModel =
  (DecisionTreeClassificationModel) (model.stages()[2]);
System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
```
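`VectorIndexer` with `setMaxCategories(4)` decides, feature by feature, whether a feature is categorical. The pure-Python sketch below illustrates that rule under stated assumptions. A feature with at most `max_categories` distinct values is treated as categorical, and its values are re-coded as 0, 1, 2, and so on. Features with more distinct values are left unchanged as continuous. Putting value 0.0 first follows the VectorIndexer documentation's statement that value 0, when present, maps to index 0. Sorting the remaining values is an assumption of this sketch. All names and data are hypothetical.

```python
def fit_vector_indexer(rows, max_categories=4):
    """Return {feature index: {value: category index}} for features deemed categorical."""
    category_maps = {}
    for j in range(len(rows[0])):
        distinct = {row[j] for row in rows}
        # At most max_categories distinct values -> categorical; otherwise continuous.
        if len(distinct) <= max_categories:
            # Put 0.0 first (if present), then the remaining values in ascending order.
            ordered = sorted(distinct, key=lambda v: (v != 0.0, v))
            category_maps[j] = {v: float(i) for i, v in enumerate(ordered)}
    return category_maps


def transform(rows, category_maps):
    """Replace categorical values by their indices; leave continuous features as-is."""
    return [
        [category_maps[j][v] if j in category_maps else v for j, v in enumerate(row)]
        for row in rows
    ]


# Hypothetical 3-feature dataset.
rows = [
    [0.0, 3.5, 10.0],
    [1.0, 1.2, 20.0],
    [0.0, 7.9, 10.0],
    [1.0, 2.2, 30.0],
    [0.0, 5.0, 40.0],
]
maps = fit_vector_indexer(rows, max_categories=4)
print(sorted(maps))  # → [0, 2]
print(transform(rows, maps))
```

Feature 1 has five distinct values, so with `max_categories=4` it stays continuous. Features 0 and 2 are indexed as categorical.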