iceberg/
test_utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Test utilities.
19//! This module is pub just for internal testing.
20//! It is subject to change and is not intended to be used by external users.
21
22use std::sync::OnceLock;
23
24use arrow_array::RecordBatch;
25use expect_test::Expect;
26use itertools::Itertools;
27
28use crate::runtime::Runtime;
29
30/// Returns a process-wide [`Runtime`] suitable for tests that need to construct
31/// a [`Table`](crate::table::Table) outside a tokio context.
32///
33/// The returned [`Runtime`] wraps a single shared multi-thread tokio runtime
34/// that is lazily built on first call and lives until process exit. Cloning is
35/// cheap, so test code can call this every time it needs a runtime to feed
36/// into [`TableBuilder::runtime`](crate::table::TableBuilder::runtime).
37pub fn test_runtime() -> Runtime {
38    static TOKIO_RT: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
39    let tokio_rt = TOKIO_RT.get_or_init(|| {
40        tokio::runtime::Builder::new_multi_thread()
41            .enable_all()
42            .build()
43            .expect("failed to build test tokio runtime")
44    });
45    Runtime::new(tokio_rt)
46}
47
48/// Snapshot testing to check the resulting record batch.
49///
50/// - `expected_schema/data`: put `expect![[""]]` as a placeholder,
51///   and then run test with `UPDATE_EXPECT=1 cargo test` to automatically update the result,
52///   or use rust-analyzer (see [video](https://github.com/rust-analyzer/expect-test)).
53///   Check the doc of [`expect_test`] for more details.
54/// - `ignore_check_columns`: Some columns are not stable, so we can skip them.
55/// - `sort_column`: The order of the data might be non-deterministic, so we can sort it by a column.
56pub fn check_record_batches(
57    record_batches: Vec<RecordBatch>,
58    expected_schema: Expect,
59    expected_data: Expect,
60    ignore_check_columns: &[&str],
61    sort_column: Option<&str>,
62) {
63    assert!(!record_batches.is_empty(), "Empty record batches");
64
65    // Combine record batches using the first batch's schema
66    let first_batch = record_batches.first().unwrap();
67    let record_batch =
68        arrow_select::concat::concat_batches(&first_batch.schema(), &record_batches).unwrap();
69
70    let mut columns = record_batch.columns().to_vec();
71    if let Some(sort_column) = sort_column {
72        let column = record_batch.column_by_name(sort_column).unwrap();
73        let indices = arrow_ord::sort::sort_to_indices(column, None, None).unwrap();
74        columns = columns
75            .iter()
76            .map(|column| arrow_select::take::take(column.as_ref(), &indices, None).unwrap())
77            .collect_vec();
78    }
79
80    expected_schema.assert_eq(&format!(
81        "{}",
82        record_batch.schema().fields().iter().format(",\n")
83    ));
84    expected_data.assert_eq(&format!(
85        "{}",
86        record_batch
87            .schema()
88            .fields()
89            .iter()
90            .zip_eq(columns)
91            .map(|(field, column)| {
92                if ignore_check_columns.contains(&field.name().as_str()) {
93                    format!("{}: (skipped)", field.name())
94                } else {
95                    format!("{}: {:?}", field.name(), column)
96                }
97            })
98            .format(",\n")
99    ));
100}