k-means算法求解anchors
文字内容以后再补充:
import numpy as np# 定义Box类,描述bounding box的坐标class Box(): def __init__(self, x, y, w, h): self.x = x self.y = y self.w = w self.h = hdef box_iou(a, b): ‘‘‘ # a和b都是Box类型实例 # 返回值area是box a 和box b 的交集面积 ‘‘‘ a_x1 = a.x-a.w/2 a_y1 = a.y - a.h / 2 a_x2 = a.x+a.w/2 a_y2 = a.y + a.h / 2 b_x1 = b.x-b.w/2 b_y1 = b.y - b.h / 2 b_x2 = b.x+b.w/2 b_y2 = b.y + b.h / 2 box_x1 = max(a_x1,b_x1) box_y1 = max(a_y1, b_y1) box_x2 = min(a_x2,b_x2) box_y2 = min(a_y2, b_y2) box_w = box_x2-box_x1 box_h = box_y2 - box_y1 if box_w < 0 or box_h < 0: area = 0 else: area = box_w * box_h box_intersection=area box_union = a.w * a.h + b.w * b.h-box_intersection iou = box_intersection/box_union return iou# 使用k-means ++ 初始化 centroids,减少随机初始化的centroids对最终结果的影响def init_centroids(boxes, n_anchors): ‘‘‘ 随机选择一个box作为 :param boxes: 是所有bounding boxes的Box对象列表 :param n_anchors: n_anchors是k-means的k值 :return: 返回值centroids 是初始化的n_anchors个centroid ‘‘‘ centroids = [] boxes_num = len(boxes) centroid_index = np.random.choice(boxes_num, 1) # 在boxes_num=55 中产生一个数23 centroids.append(boxes[centroid_index]) print(centroids[0].w, centroids[0].h) for centroid_index in range(0, n_anchors-1): sum_distance = 0 distance_list = [] cur_sum = 0 for box in boxes: min_distance = 1 for centroid_i, centroid in enumerate(centroids): distance = (1 - box_iou(box, centroid)) if distance < min_distance: min_distance = distance sum_distance += min_distance distance_list.append(min_distance) distance_thresh = sum_distance*np.random.random() for i in range(0, boxes_num): cur_sum += distance_list[i] if cur_sum > distance_thresh: centroids.append(boxes[i]) print(boxes[i].w, boxes[i].h) break return centroids# 进行 k-means 计算新的centroidsdef do_kmeans(n_anchors, boxes, centroids): ‘‘‘ :param n_anchors: 是k-means的k值 :param boxes: 是所有bounding boxes的Box对象列表 :param centroids: 是所有簇的中心 :return: # 返回值new_centroids 是计算出的新簇中心 # 返回值groups是n_anchors个簇包含的boxes的列表 # 返回值loss是所有box距离所属的最近的centroid的距离的和 ‘‘‘ loss = 0 groups = [] new_centroids = [] for i in range(n_anchors): groups.append([]) # [[], [], [], []] new_centroids.append(Box(0, 0, 0, 0)) # 以上代码建立初始化 for box in boxes: min_distance = 1 group_index = 0 for centroid_index, centroid in enumerate(centroids): # 这个循环实际是在找box与哪个centroidsiou最小,最接近 distance = (1 - box_iou(box, centroid)) if distance < min_distance: min_distance = distance group_index = centroid_index groups[group_index].append(box) # 将其保留对应的族中 loss += min_distance new_centroids[group_index].w += box.w # 累加对应的族中的w new_centroids[group_index].h += box.h for i in range(n_anchors): # 得到新的族中的w与h new_centroids[i].w /= len(groups[i]) new_centroids[i].h /= len(groups[i]) return new_centroids, groups, lossdef init_all_value(use_init_centroids=1, n_anchors=9): # 构建初始化族中心 if use_init_centroids: centroids = init_centroids(boxes, n_anchors) else: centroid_indices = np.random.choice(len(boxes), n_anchors) centroids = [] for centroid_index in centroid_indices: centroids.append(boxes[centroid_index]) # 构建初始化 groups 保存对应族的box类 centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids) return centroids, groups, old_lossif __name__==‘__main__‘: # 构建boxes boxes=[] boxes.append(Box(4,5,6,7)) # 根据实际情况自己填写 num_anchor = 9 # 产生族中心点是多少 # 构建停止条件 num_iterations=2000 loss_stop = 1e-6 centroids, groups, old_loss = init_all_value(1, num_anchor) # 循环找到族中最好的w与h iterations = 1 while (True): centroids, groups, loss = do_kmeans(num_anchor, boxes, centroids) iterations = iterations + 1 print("loss = %f" % loss) if abs(old_loss - loss) < loss_stop or iterations > num_iterations: break old_loss = loss # 打印最终结果 for centroid in centroids: print("k-means result:\n") print(centroid.w, centroid.h)
相关推荐
zhanghao 2020-06-16
tichangde 2020-05-19
xsgnzb 2020-03-09
bertzhang 2019-12-25
RexLeee 2019-12-23
MaureenChen 2019-12-03
wangjie 2019-11-10
谢恩铭 2016-01-10
时间猎人 2019-09-06
Jolestar 2016-06-02
gigipop 2019-07-01
标题 2019-07-01
jianqi 2019-07-01
RLanffy 2019-07-01
VVVinegar 2019-07-01
GeolageWu 2019-06-30
CoderChang 2019-06-30
zeweig 2015-06-29