iceberg/expr/visitors/
inclusive_projection.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19
20use fnv::FnvHashSet;
21
22use crate::expr::visitors::bound_predicate_visitor::{BoundPredicateVisitor, visit};
23use crate::expr::{BoundPredicate, BoundReference, Predicate};
24use crate::spec::{Datum, PartitionField, PartitionSpecRef};
25use crate::{Error, ErrorKind};
26
27pub(crate) struct InclusiveProjection {
28    partition_spec: PartitionSpecRef,
29    cached_parts: HashMap<i32, Vec<PartitionField>>,
30}
31
32impl InclusiveProjection {
33    pub(crate) fn new(partition_spec: PartitionSpecRef) -> Self {
34        Self {
35            partition_spec,
36            cached_parts: HashMap::new(),
37        }
38    }
39
40    fn get_parts_for_field_id(&mut self, field_id: i32) -> &Vec<PartitionField> {
41        if let std::collections::hash_map::Entry::Vacant(e) = self.cached_parts.entry(field_id) {
42            let mut parts: Vec<PartitionField> = vec![];
43            for partition_spec_field in self.partition_spec.fields() {
44                if partition_spec_field.source_id == field_id {
45                    parts.push(partition_spec_field.clone())
46                }
47            }
48
49            e.insert(parts);
50        }
51
52        &self.cached_parts[&field_id]
53    }
54
55    pub(crate) fn project(&mut self, predicate: &BoundPredicate) -> crate::Result<Predicate> {
56        visit(self, predicate)
57    }
58
59    fn get_parts(
60        &mut self,
61        reference: &BoundReference,
62        predicate: &BoundPredicate,
63    ) -> Result<Predicate, Error> {
64        let field_id = reference.field().id;
65
66        // This could be made a bit neater if `try_reduce` ever becomes stable
67        self.get_parts_for_field_id(field_id)
68            .iter()
69            .try_fold(Predicate::AlwaysTrue, |res, part| {
70                Ok(
71                    if let Some(pred_for_part) = part.transform.project(&part.name, predicate)? {
72                        if res == Predicate::AlwaysTrue {
73                            pred_for_part
74                        } else {
75                            res.and(pred_for_part)
76                        }
77                    } else {
78                        res
79                    },
80                )
81            })
82    }
83}
84
85impl BoundPredicateVisitor for InclusiveProjection {
86    type T = Predicate;
87
88    fn always_true(&mut self) -> crate::Result<Self::T> {
89        Ok(Predicate::AlwaysTrue)
90    }
91
92    fn always_false(&mut self) -> crate::Result<Self::T> {
93        Ok(Predicate::AlwaysFalse)
94    }
95
96    fn and(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result<Self::T> {
97        Ok(lhs.and(rhs))
98    }
99
100    fn or(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result<Self::T> {
101        Ok(lhs.or(rhs))
102    }
103
104    fn not(&mut self, _inner: Self::T) -> crate::Result<Self::T> {
105        Err(Error::new(
106            ErrorKind::Unexpected,
107            "InclusiveProjection should not be performed against Predicates that contain a Not operator. Ensure that \"Rewrite Not\" gets applied to the originating Predicate before binding it.",
108        ))
109    }
110
111    fn is_null(
112        &mut self,
113        reference: &BoundReference,
114        predicate: &BoundPredicate,
115    ) -> crate::Result<Self::T> {
116        self.get_parts(reference, predicate)
117    }
118
119    fn not_null(
120        &mut self,
121        reference: &BoundReference,
122        predicate: &BoundPredicate,
123    ) -> crate::Result<Self::T> {
124        self.get_parts(reference, predicate)
125    }
126
127    fn is_nan(
128        &mut self,
129        reference: &BoundReference,
130        predicate: &BoundPredicate,
131    ) -> crate::Result<Self::T> {
132        self.get_parts(reference, predicate)
133    }
134
135    fn not_nan(
136        &mut self,
137        reference: &BoundReference,
138        predicate: &BoundPredicate,
139    ) -> crate::Result<Self::T> {
140        self.get_parts(reference, predicate)
141    }
142
143    fn less_than(
144        &mut self,
145        reference: &BoundReference,
146        _literal: &Datum,
147        predicate: &BoundPredicate,
148    ) -> crate::Result<Self::T> {
149        self.get_parts(reference, predicate)
150    }
151
152    fn less_than_or_eq(
153        &mut self,
154        reference: &BoundReference,
155        _literal: &Datum,
156        predicate: &BoundPredicate,
157    ) -> crate::Result<Self::T> {
158        self.get_parts(reference, predicate)
159    }
160
161    fn greater_than(
162        &mut self,
163        reference: &BoundReference,
164        _literal: &Datum,
165        predicate: &BoundPredicate,
166    ) -> crate::Result<Self::T> {
167        self.get_parts(reference, predicate)
168    }
169
170    fn greater_than_or_eq(
171        &mut self,
172        reference: &BoundReference,
173        _literal: &Datum,
174        predicate: &BoundPredicate,
175    ) -> crate::Result<Self::T> {
176        self.get_parts(reference, predicate)
177    }
178
179    fn eq(
180        &mut self,
181        reference: &BoundReference,
182        _literal: &Datum,
183        predicate: &BoundPredicate,
184    ) -> crate::Result<Self::T> {
185        self.get_parts(reference, predicate)
186    }
187
188    fn not_eq(
189        &mut self,
190        reference: &BoundReference,
191        _literal: &Datum,
192        predicate: &BoundPredicate,
193    ) -> crate::Result<Self::T> {
194        self.get_parts(reference, predicate)
195    }
196
197    fn starts_with(
198        &mut self,
199        reference: &BoundReference,
200        _literal: &Datum,
201        predicate: &BoundPredicate,
202    ) -> crate::Result<Self::T> {
203        self.get_parts(reference, predicate)
204    }
205
206    fn not_starts_with(
207        &mut self,
208        reference: &BoundReference,
209        _literal: &Datum,
210        predicate: &BoundPredicate,
211    ) -> crate::Result<Self::T> {
212        self.get_parts(reference, predicate)
213    }
214
215    fn r#in(
216        &mut self,
217        reference: &BoundReference,
218        _literals: &FnvHashSet<Datum>,
219        predicate: &BoundPredicate,
220    ) -> crate::Result<Self::T> {
221        self.get_parts(reference, predicate)
222    }
223
224    fn not_in(
225        &mut self,
226        reference: &BoundReference,
227        _literals: &FnvHashSet<Datum>,
228        predicate: &BoundPredicate,
229    ) -> crate::Result<Self::T> {
230        self.get_parts(reference, predicate)
231    }
232}
233
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::expr::visitors::inclusive_projection::InclusiveProjection;
    use crate::expr::{Bind, Predicate, Reference};
    use crate::spec::{
        Datum, NestedField, PartitionSpec, PrimitiveType, Schema, Transform, Type,
        UnboundPartitionField,
    };

    /// Schema shared by every test: required columns `a: int` (id 1),
    /// `date: date` (id 2) and `name: string` (id 3).
    fn build_test_schema() -> Schema {
        let column = |id: i32, name: &str, primitive: PrimitiveType| {
            Arc::new(NestedField::required(id, name, Type::Primitive(primitive)))
        };

        Schema::builder()
            .with_fields(vec![
                column(1, "a", PrimitiveType::Int),
                column(2, "date", PrimitiveType::Date),
                column(3, "name", PrimitiveType::String),
            ])
            .build()
            .unwrap()
    }

    /// Builds one unbound partition field through the typed builder so that
    /// all tests construct their fields the same way.
    fn partition_field(
        source_id: i32,
        field_id: i32,
        name: &str,
        transform: Transform,
    ) -> UnboundPartitionField {
        UnboundPartitionField::builder()
            .source_id(source_id)
            .name(name.to_string())
            .field_id(field_id)
            .transform(transform)
            .build()
    }

    /// Runs the whole pipeline under test and returns the projected predicate
    /// rendered as a string.
    ///
    /// Steps: build a spec (id 1) over the test schema from `partition_fields`,
    /// bind `predicate` to the schema (case-insensitively), then project it
    /// with a fresh `InclusiveProjection`.
    fn project_to_string(
        partition_fields: Vec<UnboundPartitionField>,
        predicate: Predicate,
    ) -> String {
        let schema = Arc::new(build_test_schema());

        let partition_spec = PartitionSpec::builder(schema.clone())
            .with_spec_id(1)
            .add_unbound_fields(partition_fields)
            .unwrap()
            .build()
            .unwrap();

        let bound_predicate = predicate.bind(schema, false).unwrap();

        InclusiveProjection::new(Arc::new(partition_spec))
            .project(&bound_predicate)
            .unwrap()
            .to_string()
    }

    #[test]
    fn test_inclusive_projection_logic_ops() {
        // Only logical operators and constants, projected onto a spec with no
        // partition fields: no column is involved, so the constant expression
        // comes back unchanged and renders as "TRUE".
        let predicate = Predicate::AlwaysTrue
            .and(Predicate::AlwaysFalse)
            .or(Predicate::AlwaysTrue);

        let result = project_to_string(vec![], predicate);

        assert_eq!(result, "TRUE");
    }

    #[test]
    fn test_inclusive_projection_identity_transform() {
        // A single identity partition field named like its source column:
        // the projected predicate reads exactly like the original one.
        let fields = vec![partition_field(1, 1, "a", Transform::Identity)];
        let predicate = Reference::new("a").less_than(Datum::int(10));

        let result = project_to_string(fields, predicate);

        assert_eq!(result, "a < 10");
    }

    #[test]
    fn test_inclusive_projection_date_year_transform() {
        // `date < 2024-01-01` becomes an inclusive upper bound on the year
        // partition value: 53 = 2023 - 1970, the year of the last matching day.
        let fields = vec![partition_field(2, 1000, "year", Transform::Year)];
        let predicate =
            Reference::new("date").less_than(Datum::date_from_str("2024-01-01").unwrap());

        let result = project_to_string(fields, predicate);

        assert_eq!(result, "year <= 53");
    }

    #[test]
    fn test_inclusive_projection_date_month_transform() {
        // `date < 2024-01-01` becomes an inclusive upper bound on the month
        // partition value: 647 = 53 * 12 + 11, i.e. December 2023.
        let fields = vec![partition_field(2, 1000, "month", Transform::Month)];
        let predicate =
            Reference::new("date").less_than(Datum::date_from_str("2024-01-01").unwrap());

        let result = project_to_string(fields, predicate);

        assert_eq!(result, "month <= 647");
    }

    #[test]
    fn test_inclusive_projection_date_day_transform() {
        // `date < 2024-01-01` becomes an inclusive upper bound on the day
        // partition value: the last matching day, 2023-12-31.
        let fields = vec![partition_field(2, 1000, "day", Transform::Day)];
        let predicate =
            Reference::new("date").less_than(Datum::date_from_str("2024-01-01").unwrap());

        let result = project_to_string(fields, predicate);

        assert_eq!(result, "day <= 2023-12-31");
    }

    #[test]
    fn test_inclusive_projection_truncate_transform() {
        // `name STARTS WITH "Testy McTest"` becomes
        // `name_truncate STARTS WITH "Test"`: a `Truncate(4)` partition only
        // keeps the first four characters, so every matching row lives in a
        // partition whose value starts with "Test".
        let fields = vec![partition_field(
            3,
            3,
            "name_truncate",
            Transform::Truncate(4),
        )];
        let predicate = Reference::new("name").starts_with(Datum::string("Testy McTest"));

        let result = project_to_string(fields, predicate);

        assert_eq!(result, "name_truncate STARTS WITH \"Test\"");
    }

    #[test]
    fn test_inclusive_projection_bucket_transform() {
        // `a = 10` becomes `a_bucket[7] = 2`: with a `Bucket(7)` partition the
        // value 10 is expected to land in bucket 2.
        let fields = vec![partition_field(1, 1, "a_bucket[7]", Transform::Bucket(7))];
        let predicate = Reference::new("a").equal_to(Datum::int(10));

        let result = project_to_string(fields, predicate);

        assert_eq!(result, "a_bucket[7] = 2");
    }
}