k-means算法求解anchors

文字内容以后再补充:

import numpy as np


class Box():
    """Axis-aligned bounding box described by its centre and its size.

    Attributes are plain numbers: ``x``/``y`` are the centre coordinates and
    ``w``/``h`` are the width and height.
    """

    def __init__(self, x, y, w, h):
        self.x = x
        self.y = y
        self.w = w
        self.h = h


def box_iou(a, b):
    """Return the intersection-over-union (IoU) of two ``Box`` instances.

    :param a: first Box
    :param b: second Box
    :return: intersection area divided by union area; 0.0 when the union
             has no area (e.g. both boxes are degenerate), instead of
             dividing by zero
    """
    # Convert centre/size to corner coordinates.
    a_x1 = a.x - a.w / 2
    a_y1 = a.y - a.h / 2
    a_x2 = a.x + a.w / 2
    a_y2 = a.y + a.h / 2

    b_x1 = b.x - b.w / 2
    b_y1 = b.y - b.h / 2
    b_x2 = b.x + b.w / 2
    b_y2 = b.y + b.h / 2

    # Corners of the overlapping rectangle (if any).
    box_x1 = max(a_x1, b_x1)
    box_y1 = max(a_y1, b_y1)
    box_x2 = min(a_x2, b_x2)
    box_y2 = min(a_y2, b_y2)

    box_w = box_x2 - box_x1
    box_h = box_y2 - box_y1

    # A negative extent means the boxes do not overlap on that axis.
    if box_w < 0 or box_h < 0:
        box_intersection = 0
    else:
        box_intersection = box_w * box_h

    box_union = a.w * a.h + b.w * b.h - box_intersection
    if box_union <= 0:
        return 0.0
    return box_intersection / box_union


def init_centroids(boxes, n_anchors):
    """Pick initial centroids with k-means++ seeding.

    k-means++ reduces the influence of a purely random initialisation on
    the final result: the first centroid is a uniformly random box, and each
    further centroid is drawn with probability proportional to its distance
    (1 - IoU) from the closest centroid chosen so far.

    :param boxes: list of Box objects for all bounding boxes
    :param n_anchors: the k of k-means
    :return: list of the chosen initial centroids (Box objects taken from
             ``boxes``); it holds ``n_anchors`` entries, or one entry when
             ``n_anchors`` is less than 1
    :raises ValueError: if ``boxes`` is empty
    """
    boxes_num = len(boxes)
    if boxes_num == 0:
        raise ValueError("boxes must not be empty")

    # choice() without a size returns a scalar, which is a valid list index.
    first_index = int(np.random.choice(boxes_num))
    centroids = [boxes[first_index]]
    print(centroids[0].w, centroids[0].h)

    for _ in range(n_anchors - 1):
        # Distance of every box to its nearest already-chosen centroid,
        # capped at 1 (the distance of two non-overlapping boxes).
        distance_list = [
            min(1, min(1 - box_iou(box, centroid) for centroid in centroids))
            for box in boxes
        ]
        sum_distance = sum(distance_list)

        # Roulette-wheel selection over the cumulative distances.
        distance_thresh = sum_distance * np.random.random()
        cur_sum = 0
        for box, distance in zip(boxes, distance_list):
            cur_sum += distance
            if cur_sum > distance_thresh:
                chosen = box
                break
        else:
            # Every box coincides with an existing centroid (all distances
            # are 0), so the wheel never fires; fall back to a uniform pick
            # so that the requested number of centroids is still returned.
            chosen = boxes[int(np.random.choice(boxes_num))]

        centroids.append(chosen)
        print(chosen.w, chosen.h)

    return centroids


def do_kmeans(n_anchors, boxes, centroids):
    """Run one k-means iteration and compute the new centroids.

    :param n_anchors: the k of k-means
    :param boxes: list of Box objects for all bounding boxes
    :param centroids: current cluster centres (list of Box)
    :return: tuple ``(new_centroids, groups, loss)`` where
             ``new_centroids`` is the list of updated cluster centres,
             ``groups`` is a list of ``n_anchors`` lists holding the boxes
             assigned to each cluster, and ``loss`` is the sum over all boxes
             of the distance (1 - IoU) to their nearest centroid
    """
    loss = 0
    groups = [[] for _ in range(n_anchors)]
    # New centroids only accumulate width and height; x and y stay at 0.
    new_centroids = [Box(0, 0, 0, 0) for _ in range(n_anchors)]

    for box in boxes:
        min_distance = 1
        group_index = 0
        # Find the centroid with the smallest distance (highest IoU).
        for centroid_index, centroid in enumerate(centroids):
            distance = 1 - box_iou(box, centroid)
            if distance < min_distance:
                min_distance = distance
                group_index = centroid_index
        groups[group_index].append(box)  # remember the box in its cluster
        loss += min_distance
        new_centroids[group_index].w += box.w  # accumulate for the mean
        new_centroids[group_index].h += box.h

    for i, group in enumerate(groups):
        if group:
            # Mean width/height of the cluster members.
            new_centroids[i].w /= len(group)
            new_centroids[i].h /= len(group)
        elif i < len(centroids):
            # Empty cluster: keep the previous centre instead of dividing
            # by zero, so the cluster can still attract boxes later.
            previous = centroids[i]
            new_centroids[i] = Box(previous.x, previous.y,
                                   previous.w, previous.h)

    return new_centroids, groups, loss


def init_all_value(use_init_centroids=1, n_anchors=9, box_list=None):
    """Build the initial centroids and run the first k-means iteration.

    :param use_init_centroids: truthy to seed with k-means++
                               (``init_centroids``), falsy to seed with
                               randomly chosen boxes
    :param n_anchors: the k of k-means
    :param box_list: list of Box objects to cluster; when omitted, the
                     module-level ``boxes`` list is used (original behaviour)
    :return: tuple ``(centroids, groups, old_loss)`` as returned by
             ``do_kmeans`` for the initial centroids
    """
    if box_list is None:
        box_list = boxes  # module-level list built by the script entry point

    if use_init_centroids:
        centroids = init_centroids(box_list, n_anchors)
    else:
        # Sample distinct boxes whenever there are enough of them; duplicate
        # seeds would immediately produce empty clusters.
        centroid_indices = np.random.choice(
            len(box_list), n_anchors, replace=len(box_list) < n_anchors)
        centroids = [box_list[index] for index in centroid_indices]

    # First assignment pass: builds the groups and the reference loss.
    centroids, groups, old_loss = do_kmeans(n_anchors, box_list, centroids)
    return centroids, groups, old_loss


if __name__ == '__main__':
    # Build the list of boxes.
    # NOTE: box_iou takes the box position into account, while do_kmeans
    # creates its centroids at (0, 0); for anchor clustering all boxes should
    # therefore share the same centre (for example x = y = 0).
    boxes = []
    boxes.append(Box(4, 5, 6, 7))  # fill in with your own data
    num_anchor = 9  # number of cluster centres to produce

    # Stopping criteria.
    num_iterations = 2000
    loss_stop = 1e-6

    centroids, groups, old_loss = init_all_value(1, num_anchor, boxes)

    # Iterate until the loss stops changing or the iteration cap is reached.
    iterations = 1
    while True:
        centroids, groups, loss = do_kmeans(num_anchor, boxes, centroids)
        iterations = iterations + 1
        print("loss = %f" % loss)
        if abs(old_loss - loss) < loss_stop or iterations > num_iterations:
            break
        old_loss = loss

    # Print the final result.
    for centroid in centroids:
        print("k-means result:\n")
        print(centroid.w, centroid.h)

相关推荐