[Mahout] 使用Mahout 对Kddcup 1999的数据进行分析 -- Naive Bayes

通常,在网上找到的mahout的naive bayes的例子跟官网的例子,都是针对20 newsgroup. 而且通常是命令行版本。虽然能得出预测、分类结果,但是对于Bayes具体是如何工作,以及如何处理自己的数据会比较茫然。

在努力了差不多一个星期之后,终于有点成果。

这个例子就是使用mahout 0.9 对kddcup 1999 的数据进行分析。

第一步: 下载数据。

地址: http://kdd.ics.uci.edu/databases/kddcup99/


[Mahout] 使用Mahout 对Kddcup 1999的数据进行分析 -- Naive Bayes
 

关于数据的一些简单的预处理,我们会在第二步进行。细心的你可能发现,有些数据是2007年上传的!这是因为有一些数据原来的标记有错误,后来进行了更正。

第二步: 将原始文件转换成Hadoop使用的sequence 文件。 

我们从官网知道,Bayes在mahout之中只有基于map-reduce的实现。 参考: https://mahout.apache.org/users/basics/algorithms.html 所以我们必须要将csv文件转换成hadoop使用的sequence文件

 
[Mahout] 使用Mahout 对Kddcup 1999的数据进行分析 -- Naive Bayes
 

先贴一下代码:(注意:这里列的代码,仅仅用于说明流程,并没有注意性能方面的考虑。处理过大的文件的时候,需要有针对性的自行进行调整~)

package experiment.kdd99_bayes;

import java.io.FileReader;
import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

import au.com.bytecode.opencsv.CSVReader;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

public class Kdd99CsvToSeqFile {

	private String csvPath;
	private Path seqPath;
	private SequenceFile.Writer writer;
	private Configuration conf = new Configuration();
	private Map<String, Long> word2LongMap = Maps.newHashMap();
	private List<String> strLabelList = Lists.newArrayList();
	private FileSystem fs = null;
	
	public Kdd99CsvToSeqFile(String csvFilePath, String seqPath) {
		this.csvPath = csvFilePath;
		this.seqPath = new Path(seqPath);
	}
	
	public Map<String, Long> getWordMap() {
		return word2LongMap;
	}
	
	public List<String> getLabelList() {
		return strLabelList;
	}
	
	/**
	 * Show out the already sequenced file content
	 */
	public void dump() {
		try {
			fs = FileSystem.get(conf);
			SequenceFile.Reader reader = new SequenceFile.Reader(fs, this.seqPath, conf);
			Text key = new Text();
			VectorWritable value = new VectorWritable();
			while (reader.next(key, value)) {
				System.out.println( "reading key:" + key.toString() +" with value " +
						value.toString());
			}
			reader.close();
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			try {
				fs.close();
				fs = null;
			} catch (IOException e) {
				e.printStackTrace();
			}
		}
	}
	
	/**
	 * Sequence target csv file.
	 * @param labelIndex
	 * @param hasHeader
	 */
	public void parse(int labelIndex, boolean hasHeader) {
		CSVReader reader = null;
		try {
			fs = FileSystem.getLocal(conf);
			if(fs.exists(this.seqPath))
				fs.delete(this.seqPath, true);
			writer = SequenceFile.createWriter(fs, conf, this.seqPath, Text.class, VectorWritable.class);
			reader = new CSVReader(new FileReader(this.csvPath));
			String[] header = null;
			if(hasHeader) header = reader.readNext();
			String[] line = null;
			Long l = 0L;
			while((line = reader.readNext()) != null) {
				if(labelIndex > line.length) break;
				l++;
				List<String> tmpList = Lists.newArrayList(line);
				String label = tmpList.get(labelIndex);
				if(!strLabelList.contains(label)) strLabelList.add(label);
//				Text key = new Text("/" + label + "/" + l);
				Text key = new Text("/" + label + "/");
				tmpList.remove(labelIndex);
				
				VectorWritable vectorWritable = new VectorWritable();
				Vector vector = new RandomAccessSparseVector(tmpList.size(), tmpList.size());//???
				
				for(int i = 0; i < tmpList.size(); i++) {
					String tmpStr = tmpList.get(i);
					if(StringUtils.isNumeric(tmpStr))
						vector.set(i, Double.parseDouble(tmpStr));
					else 
						vector.set(i, parseStrCell(tmpStr)); 
				}
				vectorWritable.set(vector);
				writer.append(key, vectorWritable);
			}
			
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			try {
				fs.close();
				fs = null;
				writer.close();
				reader.close();
			} catch (IOException e) {
				e.printStackTrace();
			}
		}
	}
	
	private Long parseStrCell(String str) {
		Long id = word2LongMap.get(str); 
		if( id == null) {
			id = (long) (word2LongMap.size() + 1);
			word2LongMap.put(str, id);
		} 
		return id;
	}	
}

说明一下这个代码的工作流程:

1. 初始化hadoop,比如Configuration 、 FileSystem。 

2. 通过Hadoop的 Sequence.Writer进行sequence文件的写入。其中的key/value 分别是Text 跟VectorWritable类型。

3. 通过CSVReader读入CSV文件,然后逐行遍历。如果是带标题的,则先略过第一行。

4. 对于每一行,将Array转成List方便操作。将label列从list之中删除~

5. 对于sequencefile, key为label + row number, 并且,需要以"/"作为开头,否则在实际运行的时候会提示找不到key!

6. 对于sequencefile的value,使用一个Vector进行数据承载。在此使用的是RandomAccessSparseVector,可以试着使用DenseVector进行测试,看看是否在性能上会有所改善。

在用Bayes试过了好几种数据之后,感觉对于Bayes,最关键的一步其实是在这里,因为选择那些feature、原始数据如何预处理就在这里进行了,剩下的都是模板一样的代码~ 即使命令行也一样。

第三步: 训练Bayes

在这里仅仅先贴出训练部分的代码,整体的代码最后上传

public static void train() throws Throwable	{
		System.out.println("~~~ begin to train ~~~");
		Configuration conf = new Configuration();
		FileSystem fs = FileSystem.getLocal(conf);
		TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
		trainNaiveBayes.setConf(conf);
		
		String outputDirectory = "/home/hadoop/DataSet/kdd99/bayes/output";
		String tempDirectory = "/home/hadoop/DataSet/kdd99/bayes/temp";
		
		fs.delete(new Path(outputDirectory),true);
		fs.delete(new Path(tempDirectory),true);
		// cmd sample: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -c
		trainNaiveBayes.run(new String[] { 
				"--input", trainSeqFile, 
				"--output", outputDirectory,
				"-el", 
				"--labelIndex", "labelIndex",
				"--overwrite", 
				"--tempDir", tempDirectory });
		
		// Train the classifier
		naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDirectory), conf);

		System.out.println("features: " + naiveBayesModel.numFeatures());
		System.out.println("labels: " + naiveBayesModel.numLabels());
	}

从上面的代码可以看到,熟悉命令行之后,在实际java代码编写的时候,传入进去的也是一些命令行参数。

(可能有其他方法,只是目前我还不了解~)

命令行:

// cmd sample: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -c

Java代码:

trainNaiveBayes.run

最后一步: 使用测试数据进行性能验证。

public static void test() throws IOException {
		System.out.println("~~~ begin to test ~~~");
	    AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel);
	    
	    CSVReader csv = new CSVReader(new FileReader(testFile));
	    csv.readNext(); // skip header
	    String[] line = null;
	    double totalSampleCount = 0.;
	    double correctClsCount = 0.;
	    while((line = csv.readNext()) != null) {
	    	totalSampleCount ++;
	    	Vector vector = new RandomAccessSparseVector(40,40);//???
	    	for(int i = 0; i < 40; i++) {
	    		if(StringUtils.isNumeric(line[i])) {
	    			vector.set(i, Double.parseDouble(line[i]));
	    		} else {
	    			Long id = strOptionMap.get(line[i]);
	    			if(id != null)
	    				vector.set(i, id);
	    			else {
	    				System.out.println(StringUtils.join(line, ","));
	    				continue;
	    			}
	    		}
	    	}
    		Vector resultVector = classifier.classifyFull(vector);
			int classifyResult = resultVector.maxValueIndex();
			if(StringUtils.equals(line[41], strLabelList.get(classifyResult))) {
		    	correctClsCount++;
		    } else {
		    	System.out.println("Correct=" + line[41] + "\tClassify=" + strLabelList.get(classifyResult) );
		    }
		    
	    }
	    
	    System.out.println("Correct Ratio:" + (correctClsCount / totalSampleCount));    	}

 可以看到上面的加粗部分,用的是ComplementaryNaiveBayesClassifier,另外一个贝叶斯分类器就是

StandardNaiveBayesClassifier

最后运算的结果不太好,仅有约63%的正确率~

大家可以参考下面使用Bayes对Tweet进行分类的例子,正确率能有98%这样!当然,需要各位有过功夫网的本领了~

chimpler.wordpress.com/2013/03/13/using-the-mahout-naive-bayes-classifier-to-automatically-classify-twitter-messages/

 PS: 全部java代码已经在附件之中,感兴趣的还请自取~

相关推荐