1. 概述

在本教程中,我们将探讨如何在 Java 中实现两个矩阵的乘法运算。

由于 Java 本身并未原生支持矩阵数据结构,因此我们会手动实现一个基础版本,并使用几个流行的库来完成矩阵乘法。最后,我们还将对不同实现方式进行性能基准测试,看看哪种方式最快。

2. 示例说明

我们以一个 3×2 的矩阵为例:

firstMatrix

再来看一个 2×4 的矩阵:

secondMatrix

将它们相乘后,得到一个 3×4 的结果矩阵:

resultMatrix

矩阵乘法的基本公式如下:

multiplicationAlgorithm

其中 r 是矩阵 A 的行数,c 是矩阵 B 的列数,n 是矩阵 A 的列数(必须等于矩阵 B 的行数)。

3. 矩阵乘法实现

3.1. 自定义实现

我们使用二维 double 数组来表示矩阵:

double[][] firstMatrix = {
  new double[]{1d, 5d},
  new double[]{2d, 3d},
  new double[]{1d, 7d}
};

double[][] secondMatrix = {
  new double[]{1d, 2d, 3d, 7d},
  new double[]{5d, 2d, 8d, 1d}
};

定义预期结果:

double[][] expected = {
  new double[]{26d, 12d, 43d, 12d},
  new double[]{17d, 10d, 30d, 17d},
  new double[]{36d, 16d, 59d, 14d}
};

接下来是矩阵乘法的实现:

double[][] multiplyMatrices(double[][] firstMatrix, double[][] secondMatrix) {
    double[][] result = new double[firstMatrix.length][secondMatrix[0].length];

    for (int row = 0; row < result.length; row++) {
        for (int col = 0; col < result[row].length; col++) {
            result[row][col] = multiplyMatricesCell(firstMatrix, secondMatrix, row, col);
        }
    }

    return result;
}

计算单个单元格的值:

double multiplyMatricesCell(double[][] firstMatrix, double[][] secondMatrix, int row, int col) {
    double cell = 0;
    for (int i = 0; i < secondMatrix.length; i++) {
        cell += firstMatrix[row][i] * secondMatrix[i][col];
    }
    return cell;
}

验证结果是否符合预期:

double[][] actual = multiplyMatrices(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);

✅ 实现简单,适合教学或小规模数据。

❌ 性能一般,不适合大规模数据。

3.2. EJML (Efficient Java Matrix Library)

EJML 是一个专注于性能的 Java 矩阵库。

添加依赖:

<dependency>
    <groupId>org.ejml</groupId>
    <artifactId>ejml-all</artifactId>
    <version>0.38</version>
</dependency>

创建矩阵:

SimpleMatrix firstMatrix = new SimpleMatrix(new double[][] {
  {1d, 5d},
  {2d, 3d},
  {1d, 7d}
});

SimpleMatrix secondMatrix = new SimpleMatrix(new double[][] {
  {1d, 2d, 3d, 7d},
  {5d, 2d, 8d, 1d}
});

定义预期结果:

SimpleMatrix expected = new SimpleMatrix(new double[][] {
  {26d, 12d, 43d, 12d},
  {17d, 10d, 30d, 17d},
  {36d, 16d, 59d, 14d}
});

执行乘法并验证:

SimpleMatrix actual = firstMatrix.mult(secondMatrix);
assertThat(actual).matches(m -> m.isIdentical(expected, 0d));

✅ 性能优秀,适合中小型数据。

❌ API 略显繁琐。

3.3. ND4J (Numerical Computing for Java)

ND4J 是 deeplearning4j 生态中的数值计算库,支持矩阵运算。

添加依赖:

<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native</artifactId>
    <version>1.0.0-beta4</version>
</dependency>

创建矩阵:

INDArray firstMatrix = Nd4j.create(new double[][] {
  {1d, 5d},
  {2d, 3d},
  {1d, 7d}
});

INDArray secondMatrix = Nd4j.create(new double[][] {
  {1d, 2d, 3d, 7d},
  {5d, 2d, 8d, 1d}
});

定义预期结果:

INDArray expected = Nd4j.create(new double[][] {
  {26d, 12d, 43d, 12d},
  {17d, 10d, 30d, 17d},
  {36d, 16d, 59d, 14d}
});

执行乘法并验证:

INDArray actual = firstMatrix.mmul(secondMatrix);
assertThat(actual).isEqualTo(expected);

✅ 大规模数据性能极佳,适合深度学习场景。

❌ 依赖较大,学习曲线陡峭。

3.4. Apache Commons Math3

Apache Commons Math3 是一个通用数学库,支持矩阵操作。

添加依赖:

<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.6.1</version>
</dependency>

创建矩阵:

RealMatrix firstMatrix = new Array2DRowRealMatrix(new double[][] {
  {1d, 5d},
  {2d, 3d},
  {1d, 7d}
});

RealMatrix secondMatrix = new Array2DRowRealMatrix(new double[][] {
  {1d, 2d, 3d, 7d},
  {5d, 2d, 8d, 1d}
});

定义预期结果:

RealMatrix expected = new Array2DRowRealMatrix(new double[][] {
  {26d, 12d, 43d, 12d},
  {17d, 10d, 30d, 17d},
  {36d, 16d, 59d, 14d}
});

执行乘法并验证:

RealMatrix actual = firstMatrix.multiply(secondMatrix);
assertThat(actual).isEqualTo(expected);

✅ 稳定性好,适合通用场景。

❌ 性能中等,不适合大规模数据。

3.5. LA4J (Linear Algebra for Java)

LA4J 是一个专注于线性代数的库。

添加依赖:

<dependency>
    <groupId>org.la4j</groupId>
    <artifactId>la4j</artifactId>
    <version>0.6.0</version>
</dependency>

创建矩阵:

Matrix firstMatrix = new Basic2DMatrix(new double[][] {
  {1d, 5d},
  {2d, 3d},
  {1d, 7d}
});

Matrix secondMatrix = new Basic2DMatrix(new double[][] {
  {1d, 2d, 3d, 7d},
  {5d, 2d, 8d, 1d}
});

定义预期结果:

Matrix expected = new Basic2DMatrix(new double[][] {
  {26d, 12d, 43d, 12d},
  {17d, 10d, 30d, 17d},
  {36d, 16d, 59d, 14d}
});

执行乘法并验证:

Matrix actual = firstMatrix.multiply(secondMatrix);
assertThat(actual).isEqualTo(expected);

✅ 简洁易用,适合教学或中等规模数据。

❌ 性能不如 EJML 和 ND4J。

3.6. Colt

Colt 是由 CERN 开发的高性能科学计算库。

添加依赖:

<dependency>
    <groupId>colt</groupId>
    <artifactId>colt</artifactId>
    <version>1.2.0</version>
</dependency>

创建矩阵:

DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;
DoubleMatrix2D firstMatrix = doubleFactory2D.make(new double[][] {
  {1d, 5d},
  {2d, 3d},
  {1d, 7d}
});

DoubleMatrix2D secondMatrix = doubleFactory2D.make(new double[][] {
  {1d, 2d, 3d, 7d},
  {5d, 2d, 8d, 1d}
});

定义预期结果:

DoubleMatrix2D expected = doubleFactory2D.make(new double[][] {
  {26d, 12d, 43d, 12d},
  {17d, 10d, 30d, 17d},
  {36d, 16d, 59d, 14d}
});

执行乘法并验证:

Algebra algebra = new Algebra();
DoubleMatrix2D actual = algebra.mult(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);

✅ 性能较好,适合科学计算。

❌ 文档较少,社区活跃度一般。

4. 性能测试

4.1. 小型矩阵(3×2 和 2×4)

耗时(μs/op)
自定义实现 0.389
EJML 0.226
Colt 0.219
LA4J 0.427
Apache Commons Math3 1.008
ND4J 12.670

✅ EJML 和 Colt 表现最佳。

❌ ND4J 在小型矩阵上表现不佳。

4.2. 大型矩阵(3000×3000)

耗时(秒/op)
自定义实现 497.493
Apache Commons Math3 511.140
Colt 197.914
LA4J 35.523
EJML 25.830
ND4J 0.548

✅ ND4J 在大型矩阵上遥遥领先。

❌ 自定义实现和 Apache Commons 性能极差。

5. 总结

场景 推荐库
教学/小规模数据 自定义实现、EJML、LA4J
科学计算 Colt
大规模/深度学习 ND4J
通用数学计算 Apache Commons Math3

⚠️ 选择库时需根据矩阵规模和用途综合考虑。

完整示例代码可在 GitHub 获取。


原始标题:Matrix Multiplication in Java