RandomForests(随机森林)
RandomForests是一种集成模型(Ensemble),它通过将一组基础决策树(DecisionTree)模型的判别结果组合起来,从而进行最终的分类或者回归。相比单个的DecisionTree模型,RandomForests的好处是可以非常好的避免过拟合(overfitting)问题,这一优点也是所有集成模型的优点;提高RandomForests中基础决策树的数量也能提升模型性能,但模型生成时间就会变长。
和DecisionTree一样,Spark中根据RandomForests的用法,实现了2个Estimator类,分别是用作分类(处理label是类别)的RandomForestClassifier和用作回归(处理label是连续数值)的RandomForestRegressor,从而两者也在fit训练数据后产出不同的模型对象,RandomForestClassificationModel和RandomForestRegressionModel。两者的区别主要在于各自使用的基础决策树模型和最终的模型结果组合算法上。数据集被模型对象transform后的输出都包含了最终预测的分类或数值(prediction),另外RandomForestClassificationModel额外输出各个基础决策树得到的分类结果的和(rawPrediction)以及这个数据的归一化值(probability)。
伪算法过程
Bagging 和 Boosting
集成模型通常有2种:Bagging 和 Boosting。Bagging的思想是并行的用不同的策略生成多个基础模型,然后在投票阶段进行结果组合。而Boosting的思想是从前一个基础模型的准确率出发,调整测试数据的权重或者优化目标,从而生成下一个进化的基础模型,最终按照这些基础模型的权重得出最终结果。RandomForests是一种Bagging算法,它会按照不同的划分策略(随机选择参与备选划分的特征),针对训练数据的不同的抽样子集(比如bootstrapping抽样),生成不同的决策树,最后根据这些决策树的结果组合出最终结果。
投票
集成算法最终要根据不同的基础决策树来计算最终的结果,这个过程也可以被称作投票。RandomForests被作为分类器时,投票算法就是简单的看那个分类结果得到的票数多。当被作为回归器的时候,就会取各个基础决策树的预测结果的平均值作为最终结果。
参数
RandomForests会在计算过程中使用DecisionTree算法,所以很多DecisionTree中使用的参数在这里也适用,就不一一列出了。
输入输出相关:
- RandomForestClassificationModel.rawPredictionCol: 输出结果数据中存放各个决策树得到的各个label的个数的字段名称 (默认值: rawPrediction)
- RandomForestClassificationModel.probabilityCol: 输出结果数据中存放
rawPredictionCol
归一化值的字段名称 (默认值: probability) - RandomForestClassificationModel.predictionCol: 输出结果数据中最终判别类别的字段名称 (默认值: prediction)
- RandomForestRegressionModel.predictionCol: 输出结果数据中最终预测的数据值的字段名称 (默认值: prediction)
算法模型相关:
- featureSubsetStrategy: 节点分裂时特性选取策略,减少备选特性数量可以加快模型生成速度,但太低会导致模型性能下降(默认值:auto):
可选项1. auto-系统自动选择,如果numTrees参数值为1(单颗树,非森林),则选择all;如果numTrees参数值大于1,则分类时选择sqrt,回归时选择onethird。
可选项2. all-选所有特征。
可选项3. onethird-随机选1/3的特征。
可选项4. sqrt-随机选开方个数特征。
可选项5. log2-随机选以2为底的对数个特征。
可选项6. (0,1]之间的数字f-随机选总数*f个特征。 - numTrees: 随机森林包含的决策树个数,提高这个值能提升模型性能,但模型生成时间就会线性增长。(>=1整数,默认值: 20)
- subsamplingRate: 随机抽样率,生成每棵决策树时选取的测试数据集的抽样子集,子集越小,模型生成速度越快。建议使用全集,不做抽样。((0,1]之间的实数,默认值: 1.0)
model对象的成员
- totalNumNodes: 随机森林包含的节点总数,包括所有决策树的内部节点和叶子节点。
- trees: 随机森林中所有树模型对象的集合,类型是
DecisionTreeClassificationModel
或者DecisionTreeRegressionModel
的数组。 - treeWeights: 随机森林中每棵树的权重。
- toDebugString: 输出整个森林的树形结构。
例子
1 | //======================= classifier ======================= |