1. 概述

聚类(Clustering)是一类无监督学习算法的统称,核心目标是自动发现数据中彼此相似的群体——无论是物品、用户还是抽象概念。

虽然这句话看起来简单,但里面涉及不少机器学习中的关键概念:什么是聚类?什么是无监督学习?

本文会先厘清这些基础概念,然后手撸一个 Java 版的 K-Means 实现,并用 Last.fm 的音乐数据做实战验证。踩坑经验、实现细节一个不落,保证你读完就能上手。


2. 无监督学习算法

在使用大多数机器学习算法前,我们通常需要提供一批标注好的样本数据,让模型从中学习规律。这批数据在机器学习领域被称为训练数据(training data),整个学习过程称为训练(training)

根据训练过程中是否需要人工标注,学习算法可分为两大类:

  • 监督学习(Supervised Learning):训练数据包含明确的标签。比如训练垃圾邮件过滤器时,每封邮件都标注了“垃圾”或“非垃圾”。数学上,我们试图从已知的输入 x 和输出 y 中推导出映射函数 *f(x)*。
  • 无监督学习(Unsupervised Learning):训练数据没有标签。比如我们有一堆音乐人数据,目标是自动发现风格相近的艺人分组。K-Means 就是典型的无监督算法。

3. 聚类算法

聚类的目标是在没有先验标签的情况下,从数据中发现内在的结构和分组。它不依赖标注样本,而是通过数据点之间的相似性来形成簇(cluster)。

3.1 K-Means 算法特点

K-Means 是最经典的聚类算法之一,其最大特点是:必须预先指定聚类数量 k。除了 K-Means,还有层次聚类(Hierarchical Clustering)、谱聚类(Spectral Clustering)等无需预设 k 的算法。

3.2 K-Means 工作原理

假设我们有一组二维数据点,目标是将其划分为若干簇:

First Step

K-Means 的执行流程如下:

  1. 初始化:随机放置 k 个质心(Centroid),质心代表簇的中心位置。例如,k=4 时随机初始化四个质心:

    Random Centroids

  2. 分配:将每个数据点分配给距离最近的质心:

    Assignment

  3. 更新:重新计算每个簇的质心,即该簇所有点的均值坐标:

    Date 10

  4. 迭代:重复分配和更新步骤,直到质心位置不再显著变化(收敛):

    Date copy

最终算法收敛,形成稳定的 k 个簇。接下来我们用 Java 实现这个过程。

3.3 特征表示

要对数据建模,首先得定义特征结构。比如一个音乐人可能有“流派=摇滚”这样的属性。我们通常将“属性+值”的组合称为特征(feature)

为了便于计算相似度,我们会将特征转化为数值型向量。例如,用户对艺人打标签后,统计每个标签的出现次数:

Screen-Shot-1398-04-29-at-22.30.58

像 Linkin Park 这样的艺人,其特征向量可能是 [rock -> 7890, nu-metal -> 700, alternative -> 520, pop -> 3]。通过比较向量,就能量化艺人之间的相似性。

我们用 Java 类 Record 表示一个数据点:

public class Record {
    private final String description;
    private final Map<String, Double> features;

    // constructor, getter, toString, equals and hashCode
}

3.4 相似度计算

在每次迭代中,需要计算数据点与质心的距离。最简单的方式是欧几里得距离(Euclidean Distance)。对于两个向量 [p1, q1][p2, q2],其距离公式为:

4febdae84cbc320c19dd13eac5060a984fd438d8

我们先定义距离计算接口,便于后续扩展:

public interface Distance {
    double calculate(Map<String, Double> f1, Map<String, Double> f2);
}

欧几里得距离实现如下:

public class EuclideanDistance implements Distance {

    @Override
    public double calculate(Map<String, Double> f1, Map<String, Double> f2) {
        double sum = 0;
        for (String key : f1.keySet()) {
            Double v1 = f1.get(key);
            Double v2 = f2.get(key);

            if (v1 != null && v2 != null) {
                sum += Math.pow(v1 - v2, 2);
            }
        }

        return Math.sqrt(sum);
    }
}

⚠️ 注意:只计算两个向量共有的特征维度,避免空值干扰。

除了欧氏距离,还可以用皮尔逊相关系数(Pearson Correlation)等指标,接口设计让切换变得简单。

3.5 质心表示

质心与特征向量处于同一空间,因此结构类似:

public class Centroid {
    private final Map<String, Double> coordinates;

    // constructors, getter, toString, equals and hashCode
}

3.6 质心初始化

随机初始化质心时,若不考虑数据范围,可能导致收敛缓慢。最佳实践是:在每个特征的最小值和最大值之间生成随机坐标

实现步骤:

  1. 遍历数据集,统计每个特征的 min/max
  2. 在 [min, max] 区间内生成 k 个随机质心
private static List<Centroid> randomCentroids(List<Record> records, int k) {
    List<Centroid> centroids = new ArrayList<>();
    Map<String, Double> maxs = new HashMap<>();
    Map<String, Double> mins = new HashMap<>();

    for (Record record : records) {
        record.getFeatures().forEach((key, value) -> {
            maxs.compute(key, (k1, max) -> max == null || value > max ? value : max);
            mins.compute(key, (k1, min) -> min == null || value < min ? value : min);
        });
    }

    Set<String> attributes = records.stream()
      .flatMap(e -> e.getFeatures().keySet().stream())
      .collect(toSet());
    
    for (int i = 0; i < k; i++) {
        Map<String, Double> coordinates = new HashMap<>();
        for (String attribute : attributes) {
            double max = maxs.get(attribute);
            double min = mins.get(attribute);
            coordinates.put(attribute, random.nextDouble() * (max - min) + min);
        }
        centroids.add(new Centroid(coordinates));
    }

    return centroids;
}

3.7 数据点分配

给定一个数据点,找到距离最近的质心:

private static Centroid nearestCentroid(Record record, List<Centroid> centroids, Distance distance) {
    double minimumDistance = Double.MAX_VALUE;
    Centroid nearest = null;

    for (Centroid centroid : centroids) {
        double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates());
        if (currentDistance < minimumDistance) {
            minimumDistance = currentDistance;
            nearest = centroid;
        }
    }

    return nearest;
}

将数据点加入对应簇:

private static void assignToCluster(Map<Centroid, List<Record>> clusters,  
  Record record, 
  Centroid centroid) {
    clusters.compute(centroid, (key, list) -> {
        if (list == null) {
            list = new ArrayList<>();
        }
        list.add(record);
        return list;
    });
}

3.8 质心更新

若某质心无任何数据点归属,则保持原位置;否则,更新为该簇所有点的均值坐标:

private static Centroid average(Centroid centroid, List<Record> records) {
    if (records == null || records.isEmpty()) { 
        return centroid;
    }

    Map<String, Double> average = new HashMap<>();
    Set<String> keys = records.stream()
        .flatMap(e -> e.getFeatures().keySet().stream())
        .collect(toSet());
    
    // 初始化为 0
    keys.forEach(k -> average.put(k, 0.0));
    
    // 累加所有值
    for (Record record : records) {
        record.getFeatures().forEach(
          (k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue)
        );
    }

    // 计算均值
    average.forEach((k, v) -> average.put(k, v / records.size()));
    return new Centroid(average);
}

批量更新所有质心:

private static List<Centroid> relocateCentroids(Map<Centroid, List<Record>> clusters) {
    return clusters.entrySet().stream()
        .map(e -> average(e.getKey(), e.getValue()))
        .collect(toList());
}

3.9 完整实现

整合所有步骤,主循环逻辑如下:

public static Map<Centroid, List<Record>> fit(List<Record> records, 
  int k, 
  Distance distance, 
  int maxIterations) {

    List<Centroid> centroids = randomCentroids(records, k);
    Map<Centroid, List<Record>> clusters = new HashMap<>();
    Map<Centroid, List<Record>> lastState = new HashMap<>();

    for (int i = 0; i < maxIterations; i++) {
        boolean isLastIteration = i == maxIterations - 1;

        // 分配所有点到最近质心
        for (Record record : records) {
            Centroid centroid = nearestCentroid(record, centroids, distance);
            assignToCluster(clusters, record, centroid);
        }

        // 判断是否收敛
        boolean shouldTerminate = isLastIteration || clusters.equals(lastState);
        lastState = new HashMap<>(clusters); // 深拷贝避免引用问题
        
        if (shouldTerminate) { 
            break; 
        }

        // 更新质心并清空当前簇
        centroids = relocateCentroids(clusters);
        clusters.clear();
    }

    return lastState;
}

✅ 关键点:每次迭代后清空 clusters,避免累积历史分配。


4. 实战:Last.fm 艺人聚类

Last.fm 记录用户听歌行为,可构建艺人画像。我们用其 API 获取数据并进行聚类。

4.1 API 调用

需申请 API Key(如 api_key=abc123xyz),使用 Retrofit 调用:

public interface LastFmService {

    @GET("/2.0/?method=chart.gettopartists&format=json&limit=50")
    Call<Artists> topArtists(@Query("page") int page);

    @GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1")
    Call<Tags> topTagsFor(@Query("artist") String artist);

    @GET("/2.0/?method=chart.gettoptags&format=json&limit=100")
    Call<TopTags> topTags();
}

获取 Top 100 艺人:

private static List<String> getTop100Artists() throws IOException {
    List<String> artists = new ArrayList<>();
    for (int i = 1; i <= 2; i++) {
        artists.addAll(lastFm.topArtists(i).execute().body().all());
    }
    return artists;
}

获取 Top 100 标签:

private static Set<String> getTop100Tags() throws IOException {
    return lastFm.topTags().execute().body().all();
}

构建带标签频率的艺人数据集:

private static List<Record> datasetWithTaggedArtists(List<String> artists, 
  Set<String> topTags) throws IOException {
    List<Record> records = new ArrayList<>();
    for (String artist : artists) {
        Map<String, Double> tags = lastFm.topTagsFor(artist).execute().body().all();
        tags.entrySet().removeIf(e -> !topTags.contains(e.getKey()));
        records.add(new Record(artist, tags));
    }
    return records;
}

4.2 聚类结果

执行 K-Means(k=7):

List<String> artists = getTop100Artists();
Set<String> topTags = getTop100Tags();
List<Record> records = datasetWithTaggedArtists(artists, topTags);

Map<Centroid, List<Record>> clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000);

clusters.forEach((key, value) -> {
    System.out.println("-------------------------- CLUSTER ----------------------------");
    System.out.println(sortedCentroid(key)); 
    String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet()));
    System.out.print(members);
    System.out.println("\n");
});

输出示例:

------------------------------ CLUSTER -----------------------------------
Centroid {classic rock=65.58, rock=64.42, british=20.33, ... }
David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen, ...

------------------------------ CLUSTER -----------------------------------
Centroid {Hip-Hop=97.21, rap=64.86, hip hop=29.29, ... }
Kanye West, Post Malone, Childish Gambino, Lil Nas X, ...

------------------------------ CLUSTER -----------------------------------
Centroid {indie rock=54.0, rock=52.0, Psychedelic Rock=51.0, ... }
Tame Impala, The Black Keys

------------------------------ CLUSTER -----------------------------------
Centroid {pop=81.96, female vocalists=41.29, indie=22.79, ... }
Ed Sheeran, Taylor Swift, Rihanna, Miley Cyrus, Billie Eilish, ...

------------------------------ CLUSTER -----------------------------------
Centroid {indie=95.23, alternative=70.62, indie rock=64.46, ... }
Twenty One Pilots, The Smiths, Florence + the Machine, ...

------------------------------ CLUSTER -----------------------------------
Centroid {electronic=91.69, House=39.46, dance=38.0, ... }
Charli XCX, The Weeknd, Daft Punk, Calvin Harris, ...

------------------------------ CLUSTER -----------------------------------
Centroid {rock=87.39, alternative=72.11, alternative rock=49.17, ... }
Weezer, The White Stripes, Nirvana, Foo Fighters, Maroon 5, ...

结果基本合理,但受用户标签主观性影响,存在一定噪声。


5. 可视化

将聚类结果转为 JSON,可用 D3.js 渲染为径向树图:

Screen-Shot-1398-05-04-at-12.09.40

结构需适配 D3 的 flare.json 格式,实现略。


6. 聚类数量 k 的选择

K-Means 的硬伤是必须预设 k。如何选?两种思路:

  • 领域知识:比如明确要分“摇滚、流行、电子”三类,直接设 k=3。
  • 数学启发式:如肘部法则(Elbow Method)、轮廓系数(Silhouette)。

6.1 肘部法则

核心思想:计算不同 k 值下的误差平方和(SSE),寻找 SSE 下降拐点。

SSE 定义:所有簇内,点到质心距离的平方和。

public static double sse(Map<Centroid, List<Record>> clustered, Distance distance) {
    double sum = 0;
    for (Map.Entry<Centroid, List<Record>> entry : clustered.entrySet()) {
        Centroid centroid = entry.getKey();
        for (Record record : entry.getValue()) {
            double d = distance.calculate(centroid.getCoordinates(), record.getFeatures());
            sum += Math.pow(d, 2);
        }
    }
    return sum;
}

遍历 k ∈ [2, 16],绘制 SSE 曲线:

Screen-Shot-1398-05-04-at-17.01.36

选择“肘部”位置(如 k=9),此处增加 k 带来的收益显著下降。

⚠️ 注意:肘部法则主观性强,建议结合业务场景判断。


7. 总结

本文从零实现了一个 Java 版 K-Means:

  • ✅ 讲清了无监督学习与聚类的核心概念
  • ✅ 手撸了特征表示、距离计算、质心迭代等核心模块
  • ✅ 用 Last.fm 数据实战,验证了算法有效性
  • ✅ 探讨了 k 值选择的实用方法

代码已开源至 GitHub:https://github.com/yourname/kmeans-java(示例链接)

K-Means 虽简单,但工程中够用。理解其原理后,再看 Spark MLlib 或 Sklearn 的实现,会轻松很多。


原始标题:The K-Means Clustering Algorithm in Java