Spark StringIndexer和IndexToString

1、StringIndexer

  标签索引器,它将标签的字符串列映射到标签索引的ML列。
  如果输入列为数字,则将其强制转换为字符串并为字符串值编制索引。
  索引在[0,numLabels)中。 默认情况下,按标签频率排序,因此最常使用的标签的索引为0。

//定义一个StringIndexerModel,将label转换成indexedlabel
StringIndexerModel labelIndexerModel=new StringIndexer().
                setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(rawData);
//加labelIndexerModel加入到Pipeline中
Pipeline pipeline=new Pipeline()
                 .setStages(new PipelineStage[]
                         {labelIndexerModel,
                         featureIndexerModel,
                         dtClassifier,
                         converter});
//查看结果
pipeline.fit(rawData).transform(rawData).select("label","indexedLabel").show(20,false);
 
按label出现的频次,转换成0~num numOfLabels-1(分类个数),频次最高的转换为0,以此类推:
label=3,出现次数最多,出现了4次,转换(编号)为0
其次是label=2,出现了3次,编号为1,以此类推
+-----+------------+
|label|indexedLabel|
+-----+------------+
|3.0  |0.0         |
|4.0  |3.0         |
|1.0  |2.0         |
|3.0  |0.0         |
|2.0  |1.0         |
|3.0  |0.0         |
|2.0  |1.0         |
|3.0  |0.0         |
|2.0  |1.0         |
|1.0  |2.0         |
+-----+------------+
StringIndexer对String按频次进行编号
     id | category | categoryIndex
    ----|----------|---------------
     0  | a        | 0.0
     1  | b        | 2.0
     2  | c        | 1.0
     3  | a        | 0.0
     4  | a        | 0.0
     5  | c        | 1.0
     如果转换模型(关系)是基于上面数据得到的 (a,b,c)->(0.0,2.0,1.0),如果用此模型转换category多于(a,b,c)的数据,比如多了d,e,就会遇到麻烦:
     id | category | categoryIndex
    ----|----------|---------------
     0  | a        | 0.0
     1  | b        | 2.0
     2  | d        | ?
     3  | e        | ?
     4  | a        | 0.0
     5  | c        | 1.0
     Spark提供了两种处理方式:
     StringIndexerModel labelIndexerModel=new StringIndexer().
                    setInputCol("label")
                    .setOutputCol("indexedLabel")
                    //.setHandleInvalid("error")
                    .setHandleInvalid("skip")
                    .fit(rawData);
     (1)默认设置,也就是.setHandleInvalid("error"):会抛出异常
     org.apache.spark.SparkException: Unseen label: d,e
     (2).setHandleInvalid("skip") 忽略这些label所在行的数据,正常运行,将输出如下结果:
     id | category | categoryIndex
    ----|----------|---------------
     0  | a        | 0.0
     1  | b        | 2.0
     4  | a        | 0.0
     5  | c        | 1.0

2、IndexToString

是一个转换器“ Transformer”,它将一列索引映射回对应字符串值的新列。

索引字符串映射既可以来自输入列的ML属性,也可以来自用户提供的标签(优先于ML属性)。

相应的,有StringIndexer,就应该有IndexToString。
在应用StringIndexer对labels进行重新编号后,带着这些编号后的label对数据进行了训练,并接着对其他数据进行了预测,
得到预测结果,预测结果的label也是重新编号过的,因此需要转换回来。
见下面例子,转换回来的convetedPrediction才和原始的label对应。
IndexToString converter=new IndexToString()
                .setInputCol("prediction")//Spark默认预测label行
                .setOutputCol("convetedPrediction")//转换回来的预测label
                .setLabels(labelIndexerModel.labels());//需要指定前面建好相互相互模型
Pipeline pipeline=new Pipeline()
                 .setStages(new PipelineStage[]
                         {labelIndexerModel,
                         featureIndexerModel,
                         dtClassifier,
                         converter});
pipeline.fit(rawData).transform(rawData)
        .select("label","prediction","convetedPrediction").show(20,false);  
|label|prediction|convetedPrediction|
+-----+----------+------------------+
|3.0  |0.0       |3.0               |
|4.0  |1.0       |2.0               |
|1.0  |2.0       |1.0               |
|3.0  |0.0       |3.0               |
|2.0  |1.0       |2.0               |
|3.0  |0.0       |3.0               |
|2.0  |1.0       |2.0               |
|3.0  |0.0       |3.0               |
|2.0  |1.0       |2.0               |
|1.0  |2.0       |1.0               |
+-----+----------+------------------+

相关推荐