1. 概述

Apache Spark 是一个开源的分布式分析处理系统,支持大规模数据工程和数据科学任务。它通过统一的数据传输、大规模转换和分布式 API,简化了分析应用的开发。

DataFrame 是 Spark API 的核心组件。本教程将通过客户数据案例,深入解析 Spark DataFrame 的关键操作。

2. Spark 中的 DataFrame

从逻辑上看,DataFrame 是按命名列组织的不可变记录集合。它类似于 RDBMS 中的表或 Java 中的 ResultSet。

作为 API,DataFrame 提供了对多个 Spark 库(包括 Spark SQL、Spark Streaming、MLib 和 GraphX)的统一访问入口。

在 Java 中,我们使用 Dataset<Row> 表示 DataFrame

本质上,Row 使用名为 Tungsten 的高效存储格式,相比其前身(如 RDD)大幅优化了 Spark 操作性能。

3. Maven 依赖

首先在 pom.xml 中添加 spark-corespark-sql 依赖:

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.11</artifactId>
    <version>2.4.8</version>
</dependency>

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.11</artifactId>
    <version>2.4.8</version>
</dependency>

4. DataFrame 与 Schema

本质上,DataFrame 是带 Schema 的 RDD。Schema 可以自动推断,也可以显式定义为 StructType

StructType 是 Spark SQL 的内置数据类型,用于表示 StructField 对象集合

以下定义客户 Schema 的示例:

public static StructType minimumCustomerDataSchema() {
    return DataTypes.createStructType(new StructField[] {
      DataTypes.createStructField("id", DataTypes.StringType, true),
      DataTypes.createStructField("name", DataTypes.StringType, true),
      DataTypes.createStructField("gender", DataTypes.StringType, true),
      DataTypes.createStructField("transaction_amount", DataTypes.IntegerType, true) }
    );
}

每个 StructField 包含:

  • 列名(对应 DataFrame 列名)
  • 数据类型
  • 是否可为空(boolean 标志)

5. 构建 DataFrame

每个 Spark 应用的第一步是通过 Master 获取 SparkSession这是访问 DataFrame 的入口

public static SparkSession getSparkSession() {
    return SparkSession.builder()
      .appName("Customer Aggregation pipeline")
      .master("local")
      .getOrCreate();
}

注意:这里使用本地模式(local)。若连接集群,需替换为集群地址。

获得 SparkSession 后,可通过多种方式创建 DataFrame:

5.1. 从 List 创建

先构建 List<Customer>

List<Customer> customers = Arrays.asList(
  aCustomerWith("01", "jo", "Female", 2000), 
  aCustomerWith("02", "jack", "Male", 1200)
);

通过 createDataFrame 转换:

Dataset<Row> df = SPARK_SESSION
  .createDataFrame(customerList, Customer.class);

5.2. 从 Dataset 创建

若已有 Dataset,调用 toDF 即可转换:

Dataset<Customer> customerPOJODataSet = SPARK_SESSION
  .createDataset(CUSTOMERS, Encoders.bean(Customer.class));

Dataset<Row> df = customerPOJODataSet.toDF();

5.3. 使用 RowFactory 将 POJO 转为 Row

因 DataFrame 本质是 Dataset<Row>,需实现 MapFunction<Customer, Row>

public class CustomerToRowMapper implements MapFunction<Customer, Row> {
    
    @Override
    public Row call(Customer customer) throws Exception {
        Row row = RowFactory.create(
          customer.getId(),
          customer.getName().toUpperCase(),
          StringUtils.substring(customer.getGender(),0, 1),
          customer.getTransaction_amount()
        );
        return row;
    }
}

这里可对数据预处理(如性别截取首字母)。

5.4. 从 List 创建

List<Row> rows = customer.stream()
  .map(c -> new CustomerToRowMapper().call(c))
  .collect(Collectors.toList());

Dataset<Row> df = SparkDriver.getSparkSession()
  .createDataFrame(rows, SchemaFactory.minimumCustomerDataSchema());

Schema 会严格过滤列:未在 Schema 中定义的字段将被丢弃。

5.5. 从结构化文件和数据库创建

DataFrame API 对 CSV/JSON/数据库等格式保持一致

从多行 JSON 创建:

Dataset<Row> df = SparkDriver.getSparkSession()
  .read()
  .format("org.apache.spark.sql.execution.datasources.json.JsonFileFormat")
  .option("multiline", true)
  .load("data/minCustomerData.json");

从数据库读取:

Dataset<Row> df = SparkDriver.getSparkSession()
  .read()
  .option("url", "jdbc:postgresql://localhost:5432/customerdb")
  .option("dbtable", "customer")
  .option("user", "admin")
  .option("password", "securepass123")
  .option("serverTimezone", "EST")
  .format("jdbc")
  .load();

6. DataFrame 转 Dataset

当需操作 POJO 或使用 DataFrame 特有 API 时,需进行转换:

Dataset<Customer> ds = df.map(
  new CustomerMapper(),
  Encoders.bean(Customer.class)
);

CustomerMapper 实现 MapFunction<Row, Customer>

public class CustomerMapper implements MapFunction<Row, Customer> {

    @Override
    public Customer call(Row row) {
        Customer customer = new Customer();
        customer.setId(row.getAs("id"));
        customer.setName(row.getAs("name"));
        customer.setGender(row.getAs("gender"));
        customer.setTransaction_amount(Math.toIntExact(row.getAs("transaction_amount")));
        return customer;
    }
}

关键点MapFunction 实例只会创建一次,无论处理多少记录。

7. DataFrame 操作与转换

构建客户数据处理管道:从两个异构文件源读取数据 → 标准化 → 转换 → 写入数据库。最终按性别和来源统计年度消费。

7.1. 数据摄入

从 JSON 读取:

Dataset<Row> jsonDataToDF = SPARK_SESSION.read()
  .format("org.apache.spark.sql.execution.datasources.json.JsonFileFormat")
  .option("multiline", true)
  .load("data/customerData.json");

从 CSV 读取(需显式指定 Schema):

Dataset<Row> csvDataToDF = SPARK_SESSION.read()
  .format("csv")
  .option("header", "true")
  .schema(SchemaFactory.customerSchema())
  .option("dateFormat", "m/d/YYYY")
  .load("data/customerData.csv"); 

csvDataToDF.show(); 
csvDataToDF.printSchema(); 
return csvData;

使用 show() 查看数据,printSchema() 检查 Schema。两个源 Schema 不同,需标准化。

7.2. 标准化 DataFrame

标准化操作示例:

private Dataset<Row> normalizeCustomerDataFromEbay(Dataset<Row> rawDataset) {
    Dataset<Row> transformedDF = rawDataset
      .withColumn("id", concat(rawDataset.col("zoneId"),lit("-"), rawDataset.col("customerId")))
      .drop(column("customerId"))
      .withColumn("source", lit("ebay"))
      .withColumn("city", rawDataset.col("contact.customer_city"))
      .drop(column("contact"))
      .drop(column("zoneId"))
      .withColumn("year", functions.year(col("transaction_date")))
      .drop("transaction_date")
      .withColumn("firstName", functions.split(column("name"), " ")
        .getItem(0))
      .withColumn("lastName", functions.split(column("name"), " ")
        .getItem(1))
      .drop(column("name"));

    return transformedDF; 
}

关键操作:

  • concat + lit():拼接列生成新 ID
  • functions.year():提取年份
  • functions.split():分割姓名字段
  • drop():删除列
  • col():按名获取列引用
  • withColumnRenamed():重命名列

注意:DataFrame 是不可变的!每次修改都会生成新 DataFrame。

标准化后 Schema 统一为:

root
 |-- gender: string (nullable = true)
 |-- transaction_amount: long (nullable = true)
 |-- id: string (nullable = true)
 |-- source: string (nullable = false)
 |-- city: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- firstName: string (nullable = true)
 |-- lastName: string (nullable = true)

7.3. 合并 DataFrame

Dataset<Row> combineDataframes(Dataset<Row> df1, Dataset<Row> df2) {
    return df1.unionByName(df2); 
}

踩坑提醒

  • unionByName():按列名合并(推荐)
  • union():按列位置合并(易出错)

7.4. 聚合 DataFrame

按年、来源、性别分组统计消费,并排序:

Dataset<Row> aggDF = dataset
  .groupBy(column("year"), column("source"), column("gender"))
  .sum("transactionAmount")
  .withColumnRenamed("sum(transaction_amount)", "yearly spent")
  .orderBy(col("year").asc(), col("yearly spent").desc());

关键操作:

  • groupBy():分组(类似 SQL GROUP BY)
  • sum():聚合计算
  • orderBy():排序
  • asc()/desc():指定升序/降序

输出示例:

+----+------+------+---------------+
|year|source|gender|annual_spending|
+----+------+------+---------------+
|2018|amazon|  Male|          10600|
|2018|amazon|Female|           6200|
|2018|  ebay|  Male|           5500|
|2021|  ebay|Female|          16000|
|2021|  ebay|  Male|          13500|
|2021|amazon|  Male|           4000|
|2021|amazon|Female|           2000|
+----+------+------+---------------+

最终 Schema:

root
 |-- source: string (nullable = false)
 |-- gender: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- yearly spent: long (nullable = true)

7.5. 写入数据库

配置数据库连接:

Properties dbProps = new Properties();
dbProps.setProperty("connectionURL", "jdbc:postgresql://localhost:5432/customerdb");
dbProps.setProperty("driver", "org.postgresql.Driver");
dbProps.setProperty("user", "postgres");
dbProps.setProperty("password", "postgres");

执行写入:

String connectionURL = dbProperties.getProperty("connectionURL");

dataset.write()
  .mode(SaveMode.Overwrite)
  .jdbc(connectionURL, "customer", dbProperties);

写入模式

  • Overwrite:覆盖表(慎用!)
  • Append:追加数据
  • Ignore:表存在则跳过
  • ErrorIfExists:表存在则报错

8. 测试

使用 Docker 部署 PostgreSQL 和 pgAdmin 进行端到端测试:

@Test
void givenCSVAndJSON_whenRun_thenStoresAggregatedDataFrameInDB() throws Exception {
    Properties dbProps = new Properties();
    dbProps.setProperty("connectionURL", "jdbc:postgresql://localhost:5432/customerdb");
    dbProps.setProperty("driver", "org.postgresql.Driver");
    dbProps.setProperty("user", "postgres");
    dbProps.setProperty("password", "postgres");

    pipeline = new CustomerDataAggregationPipeline(dbProps);
    pipeline.run();

    String allCustomersSql = "Select count(*) from customer";

    Statement statement = conn.createStatement();
    ResultSet resultSet = statement.executeQuery(allCustomersSql);
    resultSet.next();
    int count = resultSet.getInt(1);
    assertEquals(7, count);
}

验证要点:

  1. customer 表自动创建
  2. 数据行数与预期一致
  3. 可通过 pgAdmin 查看结果

扩展write() 也支持导出为 CSV/JSON/Parquet 等格式。

9. 总结

本教程深入探讨了 Apache Spark DataFrame 的核心操作:

  1. 从多种源(JSON/CSV/数据库)创建 DataFrame
  2. 使用 Schema 约束数据结构
  3. 通过标准化/合并/聚合等操作处理数据
  4. 将结果持久化到数据库

完整代码示例可在 GitHub 获取。掌握这些操作,你就能简单粗暴地应对大多数数据处理场景了!


原始标题:Spark DataFrame