机器学习&数据挖掘

聚类算法Mean Shift

Author
宋天龙
发布于 2015-05-20
2274 次阅读
0 次赞
0 次分享
聚类算法Mean Shift
AI 智能核心导读

Mean Shift 是一种基于核密度估计的无参迭代算法,通过沿概率密度梯度方向寻找数据最密集区域。该算法收敛快、鲁棒性强,广泛应用于目标跟踪、图像分割与聚类等领域。在 Python 实战中,自动计算 bandwidth 参数易成为大数据集下的性能瓶颈,建议结合先验经验手动指定以保障实时计算效率。

深入解析 Mean Shift 算法:原理、应用与 Python 实战

算法概述与核心原理

Mean Shift 算法,一般是指一个迭代的步骤,即先算出当前点的偏移均值,然后以此为新的起始点,继续移动,直到满足一定的结束条件。Mean Shift 算法是一种无参密度估计算法或称核密度估计算法。Mean Shift 是一个向量,它的方向指向当前点上概率密度梯度的方向。

所谓的核密度评估算法,指的是根据数据概率密度不断移动其均值质心(也就是算法的名称 Mean Shift 的含义),直到满足一定条件。

mean-shift11
mean-shift11

上图诠释了 Mean Shift 算法的基本工作原理。那么,如何找到数据概率密度最大的区域?

数据最密集的地方,对应于概率密度最大的地方。我们可以对概率密度求梯度,梯度的方向就是概率密度增加最大的方向,从而也就是数据最密集的方向。

在目标跟踪中的应用与优缺点

Mean Shift 算法最常用于目标跟踪。它通过计算候选目标与目标模板之间相似度的概率密度分布,然后利用概率密度梯度下降的方向来获取匹配搜索的最佳路径,加速运动目标的定位和降低搜索的时间,因此在目标实时跟踪领域有着很高的应用价值。

算法优点

  • 由于采用了统计特征,因此对噪声有很强的鲁棒性
  • 由于是一个单参数算法,容易作为一个模块和别的算法集成;
  • 采用核函数直方图建模,对边缘阻挡、目标的旋转、变形以及背景运动都不敏感;
  • 算法构造了一个可以用 Mean Shift 算法进行寻优的相似度函数。由于 Mean Shift 本质上是最陡下降法,因此其寻优过程收敛速度快,使得该算法具有很好的实时性。

算法缺点

  • 缺乏必要的模板更新;
  • 跟踪过程中由于窗口宽度大小保持不变,当目标尺度有所变化时,跟踪就会失败;
  • 当目标速度较快时,跟踪效果不好;
  • 直方图特征在目标颜色特征描述方面略显匮乏,缺少空间信息。

Mean Shift 的主要应用领域

Mean Shift 算法在很多领域都有成功应用,例如图像平滑、图像分割、物体跟踪等,这些属于人工智能里面模式识别或计算机视觉的部分;另外也包括常规的聚类应用。

  • 图像平滑:图像最大质量下的像素压缩;
  • 图像分割:跟图像平滑类似的应用,但最终是将可以平滑的图像进行分离,以达到前后景或固定物理分割的目的;
  • 目标跟踪:例如针对监控视频中某个人物的动态跟踪;
  • 常规聚类:如用户聚类等。

基于 Python 的实战演示

下面基于 Python 的机器学习库 scikit-learn(SKlearn)中的 MeanShift 演示算法应用。

代码实现

python
1import numpy as np 2from sklearn.cluster import MeanShift, estimate_bandwidth 3from sklearn.datasets.samples_generator import make_blobs 4 5# 生成样本点 6centers = [[1, 1], [-1, -1], [1, -1]] 7X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6) 8 9# 通过下列代码可自动检测 bandwidth 值 10bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500) 11 12ms = MeanShift(bandwidth=bandwidth, bin_seeding=True) 13ms.fit(X) 14labels = ms.labels_ 15cluster_centers = ms.cluster_centers_ 16 17labels_unique = np.unique(labels) 18n_clusters_ = len(labels_unique) 19new_X = np.column_stack((X, labels)) 20 21print("number of estimated clusters : %d" % n_clusters_) 22print("Top 10 samples:\n", new_X[:10]) 23 24# 图像输出 25import matplotlib.pyplot as plt 26from itertools import cycle 27 28plt.figure(1) 29plt.clf() 30 31colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk') 32for k, col in zip(range(n_clusters_), colors): 33 my_members = labels == k 34 cluster_center = cluster_centers[k] 35 plt.plot(X[my_members, 0], X[my_members, 1], col + '.') 36 plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col, markeredgecolor='k', markersize=14) 37 plt.title('Estimated number of clusters: %d' % n_clusters_) 38plt.show()

执行结果

code
1number of estimated clusters : 3 2('Top 10 samples:', array( 3[[ 1.8141499 , -1.45580736, 0. ], 4 [-0.66658907, -0.29515074, 2. ], 5 [-1.49755338, -0.96610942, 2. ], 6 [ 0.34816411, -0.69885676, 0. ], 7 [ 1.80841958, 2.14678071, 1. ], 8 [ 1.02185502, -0.9430071 , 0. ], 9 [ 1.58717372, -0.85057434, 0. ], 10 [ 2.25539903, -0.22049871, 0. ], 11 [ 0.30516472, -1.6391161 , 0. ], 12 [ 0.49500464, -0.76833638, 0. ]]))

QQ截图2015052017195511
QQ截图2015052017195511

核心参数配置

MeanShift 可配置的参数中,重点是 bandwidth 值的设置:

python
1class sklearn.cluster.MeanShift(bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True)

总结

在上述实现算法中,我们强调了 Mean Shift 具有很好的实时计算性。但由于 Python 中的该算法在默认情况下会使用 sklearn.cluster.estimate_bandwidth 函数进行自动计算 bandwidth 值,而该函数的可扩展性将会成为 Mean Shift 在大量数据集下应用实时性的瓶颈。

当然,解决方法是不使用其默认的 bandwidth 计算函数,而是自己指定一个数值。这就要求操作人员对原始数据集、算法和应用场景有比较好的先验经验,一定程度上提高了应用要求。

分享
最后修订: 2015-05-20