使用Python和Doc2vec进行文本聚类
让我们假设您有一堆来自用户的文本文档,并希望从中获得一些见解。例如,如果您是市场,则可以对某些商品进行数百万次评论。另一个可能的情况是,日常用户使用您的服务创建文本文档,并且您希望将这些文档分类到某些组中,然后向用户提出这些预测类型。听起来很酷,不是吗?
问题是您事先不知道文档类型:它可能从10到数千个可能的类不等。当然,您不希望手动执行此操作。令人高兴的是,我们可以使用简单的Python代码来聚类这些文档,然后分析预测的cluster。
什么是聚类?
聚类 - 用于将类似项目分组到一个组中的无监督技术。至于文本,我们可以创建整个文本语料库的嵌入,然后比较每个句子或文本的向量(取决于您使用的嵌入)与余弦相似性。
好的,但是文本嵌入是什么?词嵌入是文本的学习表示,其中具有相同含义的单词具有相似的表示。正是这种表达单词和文档的方法可能被认为是深度学习挑战自然语言处理问题的关键突破之一。
Python代码
首先,让我们导入所有必要的Python库
import pickle
import pandas as pd
import numpy
import re
import os
import numpy as np
import gensim
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from gensim.models import Doc2Vec
然后,假设我们有一个.csv文件,我们保存了文本文档。
train= pd.read_csv(‘train.csv’)
现在我们有训练数据集,我们可以用它来创建文本嵌入。同样,在我们的例子中,一个项目是文本,我们将使用文本级嵌入 - Doc2vec。
首先,让我们准备我们的数据。我假设所有文本信息都存储在我们数据集的text文本列中。Doc2vec要求以特定方式准备文本,所以让我们为它编写一个简单的Python代码。
LabeledSentence1 = gensim.models.doc2vec.TaggedDocument
all_content_train = []
j=0
for em in train[‘text’].values:
all_content_train.append(LabeledSentence1(em,[j]))
j+=1
print(“Number of texts processed: “, j)
然后让我们定义和训练我们的模型(根据您的系统功能需要一些时间)
d2v_model = Doc2Vec(all_content_train, size = 100, window = 10, min_count = 500, workers=7, dm = 1,alpha=0.025, min_alpha=0.001)
d2v_model.train(all_content_train, total_examples=d2v_model.corpus_count, epochs=10, start_alpha=0.002, end_alpha=-0.016)
现在我们已经训练了嵌入,现在是时候对它进行聚类了。
kmeans_model = KMeans(n_clusters=4, init=’k-means++’, max_iter=100)
X = kmeans_model.fit(d2v_model.docvecs.doctag_syn0)
labels=kmeans_model.labels_.tolist()
l = kmeans_model.fit_predict(d2v_model.docvecs.doctag_syn0)
pca = PCA(n_components=2).fit(d2v_model.docvecs.doctag_syn0)
datapoint = pca.transform(d2v_model.docvecs.doctag_syn0)
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure
label1 = [“#FFFF00”, “#008000”, “#0000FF”, “#800080”]
color = [label1[i] for i in labels]
plt.scatter(datapoint[:, 0], datapoint[:, 1], c=color)
centroids = kmeans_model.cluster_centers_
centroidpoint = pca.transform(centroids)
plt.scatter(centroidpoint[:, 0], centroidpoint[:, 1], marker=’^’, s=150, c=’#000000')
plt.show()
这里,我选择了4个clusters来显示,并这样绘制
很容易看到,我们所有的数据都可以被简单地划分成cluster。