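[Mahout] Analyzing the KDD Cup 1999 data with Mahout -- Naive Bayes
Most Mahout Naive Bayes examples you find online, including the one on the official site, use the 20 Newsgroups dataset, and they are usually command-line versions. They do produce predictions and classification results, but they leave you fairly lost about how Bayes actually works and how to apply it to your own data.
This example uses Mahout 0.9 to analyze the KDD Cup 1999 data.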
Step 1: Download the data.
URL: http://kdd.ics.uci.edu/databases/kddcup99/
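Before converting anything, it helps to confirm the column layout, because the evaluation code later reads the label from a fixed position (line[41]). The sketch below assumes each row holds 41 feature columns followed by the label, and that labels keep the trailing period seen in the distributed files (e.g. normal.). The class Kdd99RowLayout and the sample row are illustrative only, not part of the original code.

```java
import java.util.Arrays;

// Sketch: sanity-check the KDD'99 row layout (assumption: 41 features, then the label).
final class Kdd99RowLayout {
    static final int NUM_FEATURES = 41;
    static final int LABEL_INDEX = NUM_FEATURES; // the label is the last column

    /** Splits one comma-separated row; limit -1 keeps trailing empty fields. */
    static String[] split(String row) {
        return row.split(",", -1);
    }

    /** True when the row has exactly 41 features plus one label. */
    static boolean hasExpectedWidth(String[] fields) {
        return fields.length == NUM_FEATURES + 1;
    }

    static String label(String[] fields) {
        return fields[LABEL_INDEX];
    }

    public static void main(String[] args) {
        // Illustrative row in the KDD'99 layout; columns 1-3 are the symbolic ones.
        String row = "0,tcp,http,SF,181,5450,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,"
                + "8,8,0.00,0.00,0.00,0.00,1.00,0.00,0.00,9,9,1.00,0.00,0.11,"
                + "0.00,0.00,0.00,0.00,0.00,normal.";
        String[] fields = split(row);
        System.out.println(fields.length + " columns, width ok: " + hasExpectedWidth(fields)); // 42 columns, width ok: true
        System.out.println("label: " + label(fields));                                       // label: normal.
        System.out.println("symbolic: " + Arrays.asList(fields[1], fields[2], fields[3]));    // symbolic: [tcp, http, SF]
    }
}
```

Running a check like this over the whole file before conversion catches blank or truncated lines early, instead of discovering them inside the conversion loop.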
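Step 2: Convert the raw file into a Hadoop sequence file.
The official site shows that Mahout only provides a MapReduce-based implementation of Naive Bayes (see https://mahout.apache.org/users/basics/algorithms.html), so the CSV file has to be converted into a Hadoop sequence file first.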
package experiment.kdd99_bayes;

import java.io.FileReader;
import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.math.NumberUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

import au.com.bytecode.opencsv.CSVReader;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

public class Kdd99CsvToSeqFile {

    private final String csvPath;
    private final Path seqPath;
    private final Configuration conf = new Configuration();
    // Symbolic value (e.g. "tcp", "http", "SF") -> numeric id, shared by all columns
    private final Map<String, Long> word2LongMap = Maps.newHashMap();
    // Distinct labels in order of first appearance. test() maps classifier output indices
    // through this list, which assumes it matches the label index written by trainnb.
    private final List<String> strLabelList = Lists.newArrayList();

    public Kdd99CsvToSeqFile(String csvFilePath, String seqPath) {
        this.csvPath = csvFilePath;
        this.seqPath = new Path(seqPath);
    }

    public Map<String, Long> getWordMap() {
        return word2LongMap;
    }

    public List<String> getLabelList() {
        return strLabelList;
    }

    /**
     * Prints the content of the generated sequence file.
     */
    public void dump() {
        SequenceFile.Reader reader = null;
        try {
            // Read from the local file system, the same one parse() writes to
            FileSystem fs = FileSystem.getLocal(conf);
            reader = new SequenceFile.Reader(fs, this.seqPath, conf);
            Text key = new Text();
            VectorWritable value = new VectorWritable();
            while (reader.next(key, value)) {
                System.out.println("reading key:" + key + " with value " + value);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // FileSystem instances are cached and shared, so only the reader is closed
            try {
                if (reader != null) reader.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Converts the CSV file into a sequence file of (Text, VectorWritable) pairs.
     * @param labelIndex column that holds the label (41 for KDD'99)
     * @param hasHeader  whether the first line is a header to skip
     */
    public void parse(int labelIndex, boolean hasHeader) {
        CSVReader reader = null;
        SequenceFile.Writer writer = null;
        try {
            FileSystem fs = FileSystem.getLocal(conf);
            if (fs.exists(this.seqPath))
                fs.delete(this.seqPath, true);
            writer = SequenceFile.createWriter(fs, conf, this.seqPath, Text.class, VectorWritable.class);
            reader = new CSVReader(new FileReader(this.csvPath));
            if (hasHeader)
                reader.readNext(); // skip the header line

            String[] line;
            long rowNumber = 0;
            while ((line = reader.readNext()) != null) {
                if (line.length <= labelIndex)
                    continue; // skip empty or truncated rows (the label column must exist)
                rowNumber++;
                List<String> tmpList = Lists.newArrayList(line);
                String label = tmpList.remove(labelIndex); // the label is not a feature
                if (!strLabelList.contains(label))
                    strLabelList.add(label);
                // Key "/label/rowNumber": the leading "/" is required by trainnb -el
                Text key = new Text("/" + label + "/" + rowNumber);

                // cardinality = number of features; initial capacity = same
                Vector vector = new RandomAccessSparseVector(tmpList.size(), tmpList.size());
                for (int i = 0; i < tmpList.size(); i++) {
                    String tmpStr = tmpList.get(i);
                    // isNumber accepts decimals such as "0.11"; StringUtils.isNumeric accepts digits only
                    // (NumberUtils.isNumber is deprecated in favor of isCreatable in commons-lang3 3.5+)
                    if (NumberUtils.isNumber(tmpStr))
                        vector.set(i, Double.parseDouble(tmpStr));
                    else
                        vector.set(i, parseStrCell(tmpStr));
                }
                writer.append(key, new VectorWritable(vector));
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                if (writer != null) writer.close();
                if (reader != null) reader.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    // Returns the id of a symbolic value, assigning the next free id (starting at 1) on first sight
    private Long parseStrCell(String str) {
        Long id = word2LongMap.get(str);
        if (id == null) {
            id = (long) (word2LongMap.size() + 1);
            word2LongMap.put(str, id);
        }
        return id;
    }
}
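1. Initialize the Hadoop objects, such as Configuration and FileSystem.
2. Write the sequence file with Hadoop's SequenceFile.Writer; the keys are of type Text and the values of type VectorWritable.
3. Read the CSV file with CSVReader and iterate over it line by line. If the file has a header, skip the first line.
4. For each line, convert the array to a List for easier manipulation, and remove the label column from the list.
5. The sequence-file key is the label plus the row number, and it must start with "/"; otherwise the job complains at runtime that it cannot find the key. With -el, Mahout splits the key on "/" and uses the second field as the label, so the key has the form /label/rowNumber.
6. The sequence-file value is a Vector that holds the features. RandomAccessSparseVector is used here; you could try DenseVector and measure whether performance improves.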
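The key rule in step 5 can be illustrated in plain Java. The sketch assumes, based on how the -el option behaves, that Mahout splits the key on "/" and takes the second field as the label; LabelKeys and its methods are hypothetical helpers, not Mahout API.

```java
// Sketch of the "/label/rowNumber" key convention used for the sequence file.
final class LabelKeys {
    /** Builds a key such as "/smurf./7" (label between slashes, row number last). */
    static String buildKey(String label, long rowNumber) {
        return "/" + label + "/" + rowNumber;
    }

    /**
     * Assumed to mirror Mahout's -el handling: split on "/" and take field 1.
     * Field 0 is the empty string in front of the leading slash.
     */
    static String extractLabel(String key) {
        String[] parts = key.split("/");
        return parts.length > 1 ? parts[1] : null;
    }

    public static void main(String[] args) {
        String key = buildKey("smurf.", 7);
        System.out.println(key + " -> " + extractLabel(key));         // /smurf./7 -> smurf.
        String noSlash = "smurf./7";                                  // missing leading "/"
        System.out.println(noSlash + " -> " + extractLabel(noSlash)); // smurf./7 -> 7
        System.out.println("smurf. -> " + extractLabel("smurf."));    // smurf. -> null
    }
}
```

Without the leading slash the second field is the row number, so every row would get its own "label", and a key that holds only the label has no second field at all.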
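After trying Naive Bayes on several datasets, my impression is that this conversion is the most important step: it is where you choose the features and decide how to preprocess the raw data. Everything after it is template-like code, and the same is true on the command line.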
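Since this step decides what the classifier actually sees, here is a minimal plain-Java sketch of the per-row encoding described above. A TreeMap<Integer, Double> stands in for Mahout's sparse vector, and one dictionary is shared by all symbolic columns, like parseStrCell. RowEncoder and isDecimal are hypothetical names. One detail worth checking in any version of this code: commons-lang3 StringUtils.isNumeric accepts only digit characters, so a value such as "0.11" would fall through to the dictionary branch; the sketch parses decimals explicitly.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

// Sketch: encode one CSV row as sparse features (hypothetical helper, plain Java).
final class RowEncoder {
    // One dictionary shared by all symbolic columns; ids start at 1.
    private final Map<String, Long> dictionary = new HashMap<>();

    /** Accepts decimals such as "0.11", which a digits-only check rejects. */
    static boolean isDecimal(String s) {
        try {
            Double.parseDouble(s);
            return true;
        } catch (NumberFormatException e) {
            return false;
        }
    }

    /** Returns the id of a symbolic value, assigning the next id on first sight. */
    long idFor(String symbol) {
        Long id = dictionary.get(symbol);
        if (id == null) {
            id = (long) dictionary.size() + 1;
            dictionary.put(symbol, id);
        }
        return id;
    }

    /** Encodes every column except labelIndex; zero values are not stored. */
    Map<Integer, Double> encode(String[] cells, int labelIndex) {
        Map<Integer, Double> features = new TreeMap<>();
        int featureIndex = 0; // indices after the label column shift down by one
        for (int col = 0; col < cells.length; col++) {
            if (col == labelIndex) {
                continue; // the label is not a feature
            }
            String cell = cells[col].trim();
            double value = isDecimal(cell) ? Double.parseDouble(cell) : idFor(cell);
            if (value != 0.0) {
                features.put(featureIndex, value);
            }
            featureIndex++;
        }
        return features;
    }

    public static void main(String[] args) {
        RowEncoder encoder = new RowEncoder();
        String[] row = {"0", "tcp", "http", "SF", "0.11", "normal."};
        System.out.println(encoder.encode(row, 5)); // {1=1.0, 2=2.0, 3=3.0, 4=0.11}
    }
}
```

Mapping symbolic values to arbitrary integer ids turns them into magnitudes (an id of 3 weighs three times an id of 1), which Naive Bayes then treats as feature weights; giving each symbolic value its own 0/1 column is an alternative worth comparing.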
Step 3: Train the Naive Bayes model.
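// Needs: org.apache.mahout.classifier.naivebayes.NaiveBayesModel,
//        org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob
// trainSeqFile (String) and naiveBayesModel (NaiveBayesModel) are static fields of the enclosing class.
public static void train() throws Throwable {
    System.out.println("~~~ begin to train ~~~");
    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.getLocal(conf);

    TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
    trainNaiveBayes.setConf(conf);

    String outputDirectory = "/home/hadoop/DataSet/kdd99/bayes/output";
    String tempDirectory = "/home/hadoop/DataSet/kdd99/bayes/temp";
    fs.delete(new Path(outputDirectory), true);
    fs.delete(new Path(tempDirectory), true);

    // cmd sample: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -c
    trainNaiveBayes.run(new String[] {
            "--input", trainSeqFile,       // sequence file written by Kdd99CsvToSeqFile.parse()
            "--output", outputDirectory,
            "-el",                         // extract labels from the "/label/..." keys
            "--labelIndex", "labelIndex",  // where the label -> index mapping is written
            "--overwrite",
            "-c",                          // complementary training, matching the classifier used in test()
            "--tempDir", tempDirectory });

    // Load the trained model from the output directory
    naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDirectory), conf);
    System.out.println("features: " + naiveBayesModel.numFeatures());
    System.out.println("labels: " + naiveBayesModel.numLabels());
}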
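The equivalent command line is: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -c. Here -i is the input sequence file, -o the model output directory, -el extracts the labels from the keys, -li is the path where the label index is written, -ow overwrites existing output, and -c trains a complementary model.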
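To see what the trained model is used for, here is a from-scratch sketch of the complement Naive Bayes rule (Rennie et al., 2003), the technique ComplementaryNaiveBayesClassifier is named after. Training only sums feature values per label and per feature; classification then weights each feature for label c by how rare it is outside c and picks the highest-scoring label. This is a simplified illustration on dense arrays that leaves out the weight normalization and other refinements, so it will not reproduce Mahout's scores exactly; the class and method names are hypothetical.

```java
// Sketch of complement Naive Bayes on dense feature vectors (not Mahout's code).
// Weight of feature f for label c:
//   w(c, f) = -log( (N_f - N_cf + alpha) / (N - N_c + alpha * F) )
// N_f: total of feature f, N_cf: total of f within label c,
// N: grand total, N_c: total of label c, F: number of features.
final class ComplementNaiveBayes {
    private final double[][] labelFeature; // N_cf
    private final double[] featureTotal;   // N_f
    private final double[] labelTotal;     // N_c
    private double total;                  // N
    private final double alpha;            // smoothing parameter

    ComplementNaiveBayes(double[][] data, int[] labels, int numLabels, double alpha) {
        int numFeatures = data[0].length;
        this.alpha = alpha;
        labelFeature = new double[numLabels][numFeatures];
        featureTotal = new double[numFeatures];
        labelTotal = new double[numLabels];
        // "Training": accumulate feature sums per label, per feature, and overall
        for (int i = 0; i < data.length; i++) {
            for (int f = 0; f < numFeatures; f++) {
                double v = data[i][f];
                labelFeature[labels[i]][f] += v;
                featureTotal[f] += v;
                labelTotal[labels[i]] += v;
                total += v;
            }
        }
    }

    /** Larger when feature f is rare outside label c, i.e. more typical of c. */
    double weight(int c, int f) {
        double num = featureTotal[f] - labelFeature[c][f] + alpha;
        double den = total - labelTotal[c] + alpha * featureTotal.length;
        return -Math.log(num / den);
    }

    /** Scores each label as sum_f x_f * w(c, f) and returns the best index (like maxValueIndex()). */
    int classify(double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < labelTotal.length; c++) {
            double score = 0;
            for (int f = 0; f < x.length; f++) {
                score += x[f] * weight(c, f);
            }
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        return best;
    }
}
```

For example, training on {3,0} and {2,0} as label 0 and {0,4} and {0,1} as label 1 with alpha = 1 gives w(0,0) = log 7, so a sample {1,0} is classified as label 0.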
Final step: Evaluate the model on the test data.
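// strOptionMap and strLabelList are the getWordMap() / getLabelList() results of the
// Kdd99CsvToSeqFile used for training; testFile is the path of the test CSV (static fields).
public static void test() throws IOException {
    System.out.println("~~~ begin to test ~~~");
    AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel);
    final int labelIndex = 41; // KDD'99: 41 features in columns 0..40, label in column 41

    CSVReader csv = new CSVReader(new FileReader(testFile));
    try {
        csv.readNext(); // skip header
        String[] line;
        double totalSampleCount = 0.;
        double correctClsCount = 0.;
        while ((line = csv.readNext()) != null) {
            if (line.length <= labelIndex)
                continue; // skip empty or truncated rows
            totalSampleCount++;
            // All 41 feature columns, the same ones the training vectors contain
            Vector vector = new RandomAccessSparseVector(labelIndex, labelIndex);
            for (int i = 0; i < labelIndex; i++) {
                // Decimal-aware check; the training conversion must use the same rule
                if (NumberUtils.isNumber(line[i])) {
                    vector.set(i, Double.parseDouble(line[i]));
                } else {
                    Long id = strOptionMap.get(line[i]);
                    if (id != null) {
                        vector.set(i, id);
                    } else {
                        // Symbolic value never seen in training: leave this feature at 0
                        System.out.println(StringUtils.join(line, ","));
                    }
                }
            }
            Vector resultVector = classifier.classifyFull(vector);
            int classifyResult = resultVector.maxValueIndex();
            if (StringUtils.equals(line[labelIndex], strLabelList.get(classifyResult))) {
                correctClsCount++;
            } else {
                System.out.println("Correct=" + line[labelIndex] + "\tClassify=" + strLabelList.get(classifyResult));
            }
        }
        System.out.println("Correct Ratio:" + (correctClsCount / totalSampleCount));
    } finally {
        csv.close();
    }
}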
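A single "Correct Ratio" can hide a lot on this dataset: smurf and neptune make up most of the 10% training subset, and the KDD'99 test data contains attack types absent from the training data, which can never be predicted correctly. A minimal sketch that also tallies accuracy per actual label is below; AccuracyReport is a hypothetical helper, fed with the same (actual, predicted) pairs the test loop compares, i.e. line[41] and strLabelList.get(maxValueIndex).

```java
import java.util.Map;
import java.util.TreeMap;

// Sketch: overall and per-label accuracy from (actual, predicted) label pairs.
final class AccuracyReport {
    private final Map<String, int[]> perLabel = new TreeMap<>(); // label -> {correct, total}
    private int correct;
    private int total;

    void record(String actual, String predicted) {
        int[] counts = perLabel.computeIfAbsent(actual, k -> new int[2]);
        counts[1]++;
        total++;
        if (actual.equals(predicted)) {
            counts[0]++;
            correct++;
        }
    }

    double overall() {
        return total == 0 ? 0.0 : (double) correct / total;
    }

    /** Fraction of samples with this actual label that were classified correctly. */
    double forLabel(String label) {
        int[] c = perLabel.get(label);
        return c == null || c[1] == 0 ? 0.0 : (double) c[0] / c[1];
    }

    public static void main(String[] args) {
        AccuracyReport report = new AccuracyReport();
        report.record("smurf.", "smurf.");
        report.record("smurf.", "smurf.");
        report.record("normal.", "normal.");
        report.record("normal.", "smurf.");
        System.out.println(report.overall());           // 0.75
        System.out.println(report.forLabel("normal.")); // 0.5
    }
}
```

Printing forLabel for every entry of strLabelList after the loop shows which attack classes the model actually separates.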
PS: The complete Java code is in the attachment; feel free to take it if you're interested.