iceberg/spec/schema/
id_reassigner.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::utils::try_insert_field;
19use super::*;
20
/// State carried while handing out fresh, consecutive field ids to a schema.
///
/// Besides the counter for the next id, it remembers which original id was
/// replaced by which new id so that references to fields by id (identifier
/// field ids, aliases) can be translated afterwards.
#[derive(Debug)]
pub struct ReassignFieldIds {
    /// The id that will be given to the next field that is visited.
    next_field_id: i32,
    /// Mapping from the original field id to the id it was replaced with.
    old_to_new_id: HashMap<i32, i32>,
}
25
26// We are not using the visitor here, as post order traversal is not desired.
27// Instead we want to re-assign all fields on one level first before diving deeper.
28impl ReassignFieldIds {
29    pub fn new(start_from: i32) -> Self {
30        Self {
31            next_field_id: start_from,
32            old_to_new_id: HashMap::new(),
33        }
34    }
35
36    pub fn reassign_field_ids(
37        &mut self,
38        fields: Vec<NestedFieldRef>,
39    ) -> Result<Vec<NestedFieldRef>> {
40        // Visit fields on the same level first
41        let outer_fields = fields
42            .into_iter()
43            .map(|field| {
44                try_insert_field(&mut self.old_to_new_id, field.id, self.next_field_id)?;
45                let new_field = Arc::unwrap_or_clone(field).with_id(self.next_field_id);
46                self.increase_next_field_id()?;
47                Ok(Arc::new(new_field))
48            })
49            .collect::<Result<Vec<_>>>()?;
50
51        // Now visit nested fields
52        outer_fields
53            .into_iter()
54            .map(|field| {
55                if field.field_type.is_primitive() {
56                    Ok(field)
57                } else {
58                    let mut new_field = Arc::unwrap_or_clone(field);
59                    *new_field.field_type = self.reassign_ids_visit_type(*new_field.field_type)?;
60                    Ok(Arc::new(new_field))
61                }
62            })
63            .collect()
64    }
65
66    fn reassign_ids_visit_type(&mut self, field_type: Type) -> Result<Type> {
67        match field_type {
68            Type::Primitive(s) => Ok(Type::Primitive(s)),
69            Type::Struct(s) => {
70                let new_fields = self.reassign_field_ids(s.fields().to_vec())?;
71                Ok(Type::Struct(StructType::new(new_fields)))
72            }
73            Type::List(l) => {
74                self.old_to_new_id
75                    .insert(l.element_field.id, self.next_field_id);
76                let mut element_field = Arc::unwrap_or_clone(l.element_field);
77                element_field.id = self.next_field_id;
78                self.increase_next_field_id()?;
79                *element_field.field_type =
80                    self.reassign_ids_visit_type(*element_field.field_type)?;
81                Ok(Type::List(ListType {
82                    element_field: Arc::new(element_field),
83                }))
84            }
85            Type::Map(m) => {
86                self.old_to_new_id
87                    .insert(m.key_field.id, self.next_field_id);
88                let mut key_field = Arc::unwrap_or_clone(m.key_field);
89                key_field.id = self.next_field_id;
90                self.increase_next_field_id()?;
91                *key_field.field_type = self.reassign_ids_visit_type(*key_field.field_type)?;
92
93                self.old_to_new_id
94                    .insert(m.value_field.id, self.next_field_id);
95                let mut value_field = Arc::unwrap_or_clone(m.value_field);
96                value_field.id = self.next_field_id;
97                self.increase_next_field_id()?;
98                *value_field.field_type = self.reassign_ids_visit_type(*value_field.field_type)?;
99
100                Ok(Type::Map(MapType {
101                    key_field: Arc::new(key_field),
102                    value_field: Arc::new(value_field),
103                }))
104            }
105        }
106    }
107
108    fn increase_next_field_id(&mut self) -> Result<()> {
109        self.next_field_id = self.next_field_id.checked_add(1).ok_or_else(|| {
110            Error::new(
111                ErrorKind::DataInvalid,
112                "Field ID overflowed, cannot add more fields",
113            )
114        })?;
115        Ok(())
116    }
117
118    pub fn apply_to_identifier_fields(&self, field_ids: HashSet<i32>) -> Result<HashSet<i32>> {
119        field_ids
120            .into_iter()
121            .map(|id| {
122                self.old_to_new_id.get(&id).copied().ok_or_else(|| {
123                    Error::new(
124                        ErrorKind::DataInvalid,
125                        format!("Identifier Field ID {id} not found"),
126                    )
127                })
128            })
129            .collect()
130    }
131
132    pub fn apply_to_aliases(
133        &self,
134        alias: BiHashMap<String, i32>,
135    ) -> Result<BiHashMap<String, i32>> {
136        alias
137            .into_iter()
138            .map(|(name, id)| {
139                self.old_to_new_id
140                    .get(&id)
141                    .copied()
142                    .ok_or_else(|| {
143                        Error::new(
144                            ErrorKind::DataInvalid,
145                            format!("Field with id {id} for alias {name} not found"),
146                        )
147                    })
148                    .map(|new_id| (name, new_id))
149            })
150            .collect()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::schema::tests::table_schema_nested;

    /// Wraps a primitive kind into a [`Type`].
    fn prim(kind: PrimitiveType) -> Type {
        Type::Primitive(kind)
    }

    /// Alias map containing the single entry `bar_alias -> field_id`.
    fn bar_alias(field_id: i32) -> BiHashMap<String, i32> {
        BiHashMap::from_iter(vec![("bar_alias".to_string(), field_id)])
    }

    /// The flat `foo` (optional string), `bar` (required int), `baz` (optional
    /// boolean) columns with the given ids.
    fn foo_bar_baz(foo_id: i32, bar_id: i32, baz_id: i32) -> Vec<NestedFieldRef> {
        vec![
            NestedField::optional(foo_id, "foo", prim(PrimitiveType::String)).into(),
            NestedField::required(bar_id, "bar", prim(PrimitiveType::Int)).into(),
            NestedField::optional(baz_id, "baz", prim(PrimitiveType::Boolean)).into(),
        ]
    }

    /// A flat schema is renumbered in declaration order, and identifier ids and
    /// aliases follow the renumbering.
    #[test]
    fn test_reassign_ids() {
        let original = Schema::builder()
            .with_schema_id(1)
            .with_identifier_field_ids(vec![3])
            .with_alias(bar_alias(3))
            .with_fields(foo_bar_baz(5, 3, 4))
            .build()
            .unwrap();

        let reassigned_schema = original
            .into_builder()
            .with_reassigned_field_ids(0)
            .build()
            .unwrap();

        let expected = Schema::builder()
            .with_schema_id(1)
            .with_identifier_field_ids(vec![1])
            .with_alias(bar_alias(1))
            .with_fields(foo_bar_baz(0, 1, 2))
            .build()
            .unwrap();

        pretty_assertions::assert_eq!(expected, reassigned_schema);
        assert_eq!(reassigned_schema.highest_field_id(), 2);
    }

    /// In a nested schema every top-level column is numbered before any nested
    /// field; nested fields are then numbered column by column.
    #[test]
    fn test_reassigned_ids_nested() {
        let reassigned_schema = table_schema_nested()
            .into_builder()
            .with_alias(bar_alias(2))
            .with_reassigned_field_ids(0)
            .build()
            .unwrap();

        // qux: list<string>
        let qux = NestedField::required(
            3,
            "qux",
            Type::List(ListType {
                element_field: NestedField::list_element(7, prim(PrimitiveType::String), true)
                    .into(),
            }),
        );

        // quux: map<string, map<string, int>>
        let quux_inner = Type::Map(MapType {
            key_field: NestedField::map_key_element(10, prim(PrimitiveType::String)).into(),
            value_field: NestedField::map_value_element(11, prim(PrimitiveType::Int), true)
                .into(),
        });
        let quux = NestedField::required(
            4,
            "quux",
            Type::Map(MapType {
                key_field: NestedField::map_key_element(8, prim(PrimitiveType::String)).into(),
                value_field: NestedField::map_value_element(9, quux_inner, true).into(),
            }),
        );

        // location: list<struct<latitude, longitude>>
        let coordinates = StructType::new(vec![
            NestedField::optional(13, "latitude", prim(PrimitiveType::Float)).into(),
            NestedField::optional(14, "longitude", prim(PrimitiveType::Float)).into(),
        ]);
        let location = NestedField::required(
            5,
            "location",
            Type::List(ListType {
                element_field: NestedField::list_element(12, Type::Struct(coordinates), true)
                    .into(),
            }),
        );

        // person: struct<name, age>
        let person = NestedField::optional(
            6,
            "person",
            Type::Struct(StructType::new(vec![
                NestedField::optional(15, "name", prim(PrimitiveType::String)).into(),
                NestedField::required(16, "age", prim(PrimitiveType::Int)).into(),
            ])),
        );

        let nested_columns: Vec<NestedFieldRef> =
            vec![qux.into(), quux.into(), location.into(), person.into()];
        let mut expected_fields = foo_bar_baz(0, 1, 2);
        expected_fields.extend(nested_columns);

        let expected = Schema::builder()
            .with_schema_id(1)
            .with_identifier_field_ids(vec![1])
            .with_alias(bar_alias(1))
            .with_fields(expected_fields)
            .build()
            .unwrap();

        pretty_assertions::assert_eq!(expected, reassigned_schema);
        assert_eq!(reassigned_schema.highest_field_id(), 16);
        assert_eq!(reassigned_schema.field_by_id(6).unwrap().name, "person");
        assert_eq!(reassigned_schema.field_by_id(16).unwrap().name, "age");
    }

    /// Two columns sharing the original id 3 must make the build fail.
    #[test]
    fn test_reassign_ids_fails_with_duplicate_ids() {
        let duplicated: Vec<NestedFieldRef> = vec![
            NestedField::required(5, "foo", prim(PrimitiveType::String)).into(),
            NestedField::optional(3, "bar", prim(PrimitiveType::Int)).into(),
            NestedField::optional(3, "baz", prim(PrimitiveType::Boolean)).into(),
        ];

        let err = Schema::builder()
            .with_schema_id(1)
            .with_identifier_field_ids(vec![5])
            .with_alias(bar_alias(3))
            .with_fields(duplicated)
            .with_reassigned_field_ids(0)
            .build()
            .unwrap_err();

        assert!(err.message().contains("'field.id' 3"));
    }
}