k均值聚类算法(二)(k-means++)

前言

这次的文章,是笔者写的k均值算法的第二篇。在这篇文章里,笔者要给自己备忘一下k-means++算法的学习。

k-means++作用

笔者在上一篇文章写过k-means算法的相关实现,具体链接如下:k-means算法实现(c语言)
我们知道,k-means算法主要应用于数值型数据的聚类,它实现起来简单、高效,但是存在如下问题:

  1. 簇的数量难以确定(亦即聚类类别数量难以确定)
  2. 受初始点影响很大,一旦初始点不恰当选择,容易导致聚类的不正确。

因此,笔者今天备忘的k-means++算法,就是针对第二个问题而提出的一种算法。

k-means++实现步骤

在k-means算法实现的文章里,笔者已经介绍过k-means算法的实现步骤,如下:

  1. 初始化一定数量的二维坐标点(x,y),点数量可自定义,所有坐标点的初始类别都为0
  2. 根据自定义的类别数,比如说我需要把数据聚类成三类,则从上述坐标点中,随机取三个点。作为类别的中心点
  3. 迭代所有坐标点,分别与三个中心点计算平方距离(就是Δx^2+Δy^2),这一平方距离可以认为是与标准类别样本的差异度,取平方距离最小的类别作为该坐标点的类别
  4. 根据步骤三,可以得到每种分类的样本,然后,再从每种分类的样本中,重新计算该分类的类别中心。具体来说,比较简单的一种做法就是该类别所有样本的x与y,分别求和,然后除与该类别的样本数
  5. 迭代步骤3、4,直到每个类别的中心点不变或变化很小为止。

笔者今天要备忘的k-means++算法,除了步骤2有所不同外,其余步骤都一致。因为k-means++要解决k-means受初始化值影响大这个问题,因此k-means++算法中,最关键的一步就是初始点的选取应该足够的离散。这样才能保证在迭代过程中各个簇都能聚类到数量相当的数据。

那么如何选取初始点才算足够离散呢??答案当然是让相距较远的点有更大的几率成为簇中心点!!
因此第二个步骤中,选取簇中心(类别中心)的实现如下:

a. 首先,定义一个簇中心数组,接着随机选取一点,作为第一个簇中心。
b. 然后,迭代所有点,把所有点到该簇中心的距离算出,记录到距离数组中(D1、D2...Dn),分别代表第n个点到簇中心的最短距离。(值得注意的是,这里的最短距离是指该点离簇中心这一数组中所有簇中心点距离中的最短距离,也就是说,假设已经选出了两个簇中心,那么迭代所有点的时候,就要把每个点分别与两个簇中心计算距离,然后选择其中最短的距离作为该点的Dn)

c. 把所有的Dn加起来,随机一个不超过Dn的数r,然后迭代所有Dn,计算r-=Dn,直至r<0,此时的点即为新的簇中心点。

d. 迭代2、3步骤,直至选出所有簇中心。
个人理解,该方法选取簇中心的关键在于,计算出所有点到簇中心的距离和,然后若Dn越大,亦即该点距离簇中心越远,与我们随机出来的r相减少后,r就越有可能小于0,也就越有可能被选做簇中心。也就满足我们选择点离散这一条件。
值得注意的是在k-means++算法中,并不是直接选择距离最远的点作为新的簇中心,只是让这样的点被选做簇中心的概率更大而已。

实现(c++)

这次的实现与上篇文章类似,只是用类以及把簇中心的选取单独拿出来作为一个函数createCentroid()

#include <time.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#define max 100000000;
using namespace std;
//节点类
class node
{
  private:
    double x;
    double y;
    int centroid;

  public:
    node(double x, double y, int centroid);
    void setX(double x);
    void setY(double y);
    void setCentroid(int centroid);
    double getX();
    double getY();
    int getCentroid();
};
node::node(double x, double y, int centroid)
{
    this->x = x;
    this->y = y;
    this->centroid = centroid;
}
void node::setX(double x)
{
    this->x = x;
}
void node::setY(double y)
{
    this->y = y;
}
void node::setCentroid(int centroid)
{
    this->centroid = centroid;
}
double node::getX()
{
    return this->x;
}
double node::getY()
{
    return this->y;
}
int node::getCentroid()
{
    return this->centroid;
}

//随机生成二维数据点
node *randomNode(int len, int range)
{
    srand((unsigned)time(NULL));
    node *allNode = (node *)malloc(len * sizeof(node));
    for (int i = 0; i < len; ++i)
    {
        allNode[i] = node((double)rand() / 2147483647 * range, (double)rand() / 2147483647 * range, 0);
    }
    return allNode;
}
node *createCentroid(int len, int centroidNum, node *allNode)
{
    node *centroidNode = (node *)malloc(len * sizeof(node));
    //随机选择一个初始点
    int nowNum = 0;
    srand((unsigned)time(NULL));
    int firstNode = rand() % len;
    centroidNode[0] = node(allNode[firstNode].getX(), allNode[firstNode].getY(), nowNum);
    ++nowNum;
    while (nowNum < centroidNum)
    {
        int *Di = (int *)malloc(len * sizeof(int));
        for (int i = 0; i < len; ++i)
        { //初始化距离
            Di[i] = max;
        }
        //总距离
        double sumD=0;
        for (int i = 0; i < len; ++i)
        {
            for (int c = 0; c < nowNum; ++c)
            {
                double dis = pow(allNode[i].getX() - centroidNode[c].getX(), 2) + pow(allNode[i].getY() - centroidNode[c].getY(), 2);
                if (dis < Di[i])
                    Di[i] = dis;
            }
            sumD+=Di[i];
        }
        //选点距离
        double r=(double)rand() / 2147483647 * sumD;
        int i=0;
        for(;r>0;i++){
            r-=Di[i];
        }
        centroidNode[nowNum]=allNode[i];
        ++nowNum;
    }
    return centroidNode;
}
//输出所有数据点
void printAll_node(int len, node *allNode)
{
    for (int i = 0; i < len; ++i)
    {
        printf("%d,(%f %f)\n", allNode[i].getCentroid(), allNode[i].getX(), allNode[i].getY());
    }
}
void kMeans_plus(int centroidNum, int len,node *allNode)
{
    //随机生成簇中心点
    node *centroidNode =createCentroid(len,centroidNum,allNode);
    for (int t = 0; t < len; ++t)
    {
        double sum = 0;
        //对每个数据点进行分类
        for (int i = 0; i < len; ++i)
        {
            int dis = max;
            int newDis;
            for (int x = 0; x < centroidNum; ++x)
            {
                //遍历所有簇中心取最近的簇
                newDis = pow((allNode[i].getX() - centroidNode[x].getX()), 2) + pow((allNode[i].getY() - centroidNode[x].getY()), 2);
                if (newDis < dis)
                {
                    dis = newDis;
                    allNode[i].setCentroid(x);
                }
            }
        }
        //重新计算簇中心
        //每个簇的计数器
        int *clusterS_num =(int*)malloc(centroidNum * sizeof(int));
        double *clusterS_sumX =(double*)malloc(centroidNum * sizeof(double));
        double *clusterS_sumY = (double*)malloc(centroidNum * sizeof(double));
        for (int i = 0; i < centroidNum; ++i)
        {
            //初始化计数器
            clusterS_num[i] = clusterS_sumX[i] = clusterS_sumY[i] = 0;
        }
        for (int i = 0; i < len; i++)
        {
            clusterS_num[allNode[i].getCentroid()]++;
            clusterS_sumX[allNode[i].getCentroid()] += allNode[i].getX();
            clusterS_sumY[allNode[i].getCentroid()] += allNode[i].getY();
        }
        //重新计算簇中心
        for (int i = 0; i < centroidNum; ++i)
        {
            centroidNode[i].setX(clusterS_sumX[i] / clusterS_num[i]);
            centroidNode[i].setY(clusterS_sumY[i] / clusterS_num[i]);
        }
        for (int i = 0; i < len; ++i)
        {
            sum = sum + pow((allNode[i].getX() - centroidNode[allNode[i].getCentroid()].getX()), 2) + pow((allNode[i].getY() - centroidNode[allNode[i].getCentroid()].getY()), 2);
        }
        free(clusterS_num);
        free(clusterS_sumX);
        free(clusterS_sumY);
    }
}
int main()
{
    node *allNode = randomNode(500, 1000);
    kMeans_plus(5,500,allNode);
    printAll_node(500, allNode);
}

相关推荐