1use std::collections::HashMap;
19
20use fnv::FnvHashSet;
21
22use crate::expr::visitors::bound_predicate_visitor::{BoundPredicateVisitor, visit};
23use crate::expr::{BoundPredicate, BoundReference, Predicate};
24use crate::spec::{Datum, PartitionField, PartitionSpecRef};
25use crate::{Error, ErrorKind};
26
27pub(crate) struct InclusiveProjection {
28 partition_spec: PartitionSpecRef,
29 cached_parts: HashMap<i32, Vec<PartitionField>>,
30}
31
32impl InclusiveProjection {
33 pub(crate) fn new(partition_spec: PartitionSpecRef) -> Self {
34 Self {
35 partition_spec,
36 cached_parts: HashMap::new(),
37 }
38 }
39
40 fn get_parts_for_field_id(&mut self, field_id: i32) -> &Vec<PartitionField> {
41 if let std::collections::hash_map::Entry::Vacant(e) = self.cached_parts.entry(field_id) {
42 let mut parts: Vec<PartitionField> = vec![];
43 for partition_spec_field in self.partition_spec.fields() {
44 if partition_spec_field.source_id == field_id {
45 parts.push(partition_spec_field.clone())
46 }
47 }
48
49 e.insert(parts);
50 }
51
52 &self.cached_parts[&field_id]
53 }
54
55 pub(crate) fn project(&mut self, predicate: &BoundPredicate) -> crate::Result<Predicate> {
56 visit(self, predicate)
57 }
58
59 fn get_parts(
60 &mut self,
61 reference: &BoundReference,
62 predicate: &BoundPredicate,
63 ) -> Result<Predicate, Error> {
64 let field_id = reference.field().id;
65
66 self.get_parts_for_field_id(field_id)
68 .iter()
69 .try_fold(Predicate::AlwaysTrue, |res, part| {
70 Ok(
71 if let Some(pred_for_part) = part.transform.project(&part.name, predicate)? {
72 if res == Predicate::AlwaysTrue {
73 pred_for_part
74 } else {
75 res.and(pred_for_part)
76 }
77 } else {
78 res
79 },
80 )
81 })
82 }
83}
84
85impl BoundPredicateVisitor for InclusiveProjection {
86 type T = Predicate;
87
88 fn always_true(&mut self) -> crate::Result<Self::T> {
89 Ok(Predicate::AlwaysTrue)
90 }
91
92 fn always_false(&mut self) -> crate::Result<Self::T> {
93 Ok(Predicate::AlwaysFalse)
94 }
95
96 fn and(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result<Self::T> {
97 Ok(lhs.and(rhs))
98 }
99
100 fn or(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result<Self::T> {
101 Ok(lhs.or(rhs))
102 }
103
104 fn not(&mut self, _inner: Self::T) -> crate::Result<Self::T> {
105 Err(Error::new(
106 ErrorKind::Unexpected,
107 "InclusiveProjection should not be performed against Predicates that contain a Not operator. Ensure that \"Rewrite Not\" gets applied to the originating Predicate before binding it.",
108 ))
109 }
110
111 fn is_null(
112 &mut self,
113 reference: &BoundReference,
114 predicate: &BoundPredicate,
115 ) -> crate::Result<Self::T> {
116 self.get_parts(reference, predicate)
117 }
118
119 fn not_null(
120 &mut self,
121 reference: &BoundReference,
122 predicate: &BoundPredicate,
123 ) -> crate::Result<Self::T> {
124 self.get_parts(reference, predicate)
125 }
126
127 fn is_nan(
128 &mut self,
129 reference: &BoundReference,
130 predicate: &BoundPredicate,
131 ) -> crate::Result<Self::T> {
132 self.get_parts(reference, predicate)
133 }
134
135 fn not_nan(
136 &mut self,
137 reference: &BoundReference,
138 predicate: &BoundPredicate,
139 ) -> crate::Result<Self::T> {
140 self.get_parts(reference, predicate)
141 }
142
143 fn less_than(
144 &mut self,
145 reference: &BoundReference,
146 _literal: &Datum,
147 predicate: &BoundPredicate,
148 ) -> crate::Result<Self::T> {
149 self.get_parts(reference, predicate)
150 }
151
152 fn less_than_or_eq(
153 &mut self,
154 reference: &BoundReference,
155 _literal: &Datum,
156 predicate: &BoundPredicate,
157 ) -> crate::Result<Self::T> {
158 self.get_parts(reference, predicate)
159 }
160
161 fn greater_than(
162 &mut self,
163 reference: &BoundReference,
164 _literal: &Datum,
165 predicate: &BoundPredicate,
166 ) -> crate::Result<Self::T> {
167 self.get_parts(reference, predicate)
168 }
169
170 fn greater_than_or_eq(
171 &mut self,
172 reference: &BoundReference,
173 _literal: &Datum,
174 predicate: &BoundPredicate,
175 ) -> crate::Result<Self::T> {
176 self.get_parts(reference, predicate)
177 }
178
179 fn eq(
180 &mut self,
181 reference: &BoundReference,
182 _literal: &Datum,
183 predicate: &BoundPredicate,
184 ) -> crate::Result<Self::T> {
185 self.get_parts(reference, predicate)
186 }
187
188 fn not_eq(
189 &mut self,
190 reference: &BoundReference,
191 _literal: &Datum,
192 predicate: &BoundPredicate,
193 ) -> crate::Result<Self::T> {
194 self.get_parts(reference, predicate)
195 }
196
197 fn starts_with(
198 &mut self,
199 reference: &BoundReference,
200 _literal: &Datum,
201 predicate: &BoundPredicate,
202 ) -> crate::Result<Self::T> {
203 self.get_parts(reference, predicate)
204 }
205
206 fn not_starts_with(
207 &mut self,
208 reference: &BoundReference,
209 _literal: &Datum,
210 predicate: &BoundPredicate,
211 ) -> crate::Result<Self::T> {
212 self.get_parts(reference, predicate)
213 }
214
215 fn r#in(
216 &mut self,
217 reference: &BoundReference,
218 _literals: &FnvHashSet<Datum>,
219 predicate: &BoundPredicate,
220 ) -> crate::Result<Self::T> {
221 self.get_parts(reference, predicate)
222 }
223
224 fn not_in(
225 &mut self,
226 reference: &BoundReference,
227 _literals: &FnvHashSet<Datum>,
228 predicate: &BoundPredicate,
229 ) -> crate::Result<Self::T> {
230 self.get_parts(reference, predicate)
231 }
232}
233
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::expr::visitors::inclusive_projection::InclusiveProjection;
    use crate::expr::{Bind, Predicate, Reference};
    use crate::spec::{
        Datum, NestedField, PartitionSpec, PrimitiveType, Schema, Transform, Type,
        UnboundPartitionField,
    };

    /// Schema shared by all tests: required columns `a` (int, id 1),
    /// `date` (date, id 2) and `name` (string, id 3).
    fn build_test_schema() -> Schema {
        let columns = vec![
            Arc::new(NestedField::required(
                1,
                "a",
                Type::Primitive(PrimitiveType::Int),
            )),
            Arc::new(NestedField::required(
                2,
                "date",
                Type::Primitive(PrimitiveType::Date),
            )),
            Arc::new(NestedField::required(
                3,
                "name",
                Type::Primitive(PrimitiveType::String),
            )),
        ];

        Schema::builder().with_fields(columns).build().unwrap()
    }

    /// Binds `unbound` to `schema` (passing `false` as the second `bind`
    /// argument, as every test here does), runs it through an
    /// `InclusiveProjection` built on `spec`, and renders the result as text.
    fn projected_string(schema: Arc<Schema>, spec: PartitionSpec, unbound: Predicate) -> String {
        let bound = unbound.bind(schema, false).unwrap();
        let mut projection = InclusiveProjection::new(Arc::new(spec));

        projection.project(&bound).unwrap().to_string()
    }

    #[test]
    fn test_inclusive_projection_logic_ops() {
        let schema = Arc::new(build_test_schema());

        // No partition fields at all: only the logical operators and
        // constants are exercised.
        let spec = PartitionSpec::builder(schema.clone())
            .with_spec_id(1)
            .build()
            .unwrap();

        let unbound = Predicate::AlwaysTrue
            .and(Predicate::AlwaysFalse)
            .or(Predicate::AlwaysTrue);

        let actual = projected_string(schema, spec, unbound);

        assert_eq!(actual, "TRUE")
    }

    #[test]
    fn test_inclusive_projection_identity_transform() {
        let schema = Arc::new(build_test_schema());

        let identity_on_a = UnboundPartitionField::builder()
            .source_id(1)
            .name("a".to_string())
            .field_id(1)
            .transform(Transform::Identity)
            .build();
        let spec = PartitionSpec::builder(schema.clone())
            .with_spec_id(1)
            .add_unbound_field(identity_on_a)
            .unwrap()
            .build()
            .unwrap();

        let unbound = Reference::new("a").less_than(Datum::int(10));

        // With an identity partition on `a`, the predicate is expected to
        // come back unchanged.
        let actual = projected_string(schema, spec, unbound);

        assert_eq!(actual, "a < 10")
    }

    #[test]
    fn test_inclusive_projection_date_year_transform() {
        let schema = Arc::new(build_test_schema());

        let year_of_date = UnboundPartitionField {
            source_id: 2,
            name: "year".to_string(),
            field_id: Some(1000),
            transform: Transform::Year,
        };
        let spec = PartitionSpec::builder(schema.clone())
            .with_spec_id(1)
            .add_unbound_fields(vec![year_of_date])
            .unwrap()
            .build()
            .unwrap();

        let unbound = Reference::new("date").less_than(Datum::date_from_str("2024-01-01").unwrap());

        // A strict upper bound on the date is expected to become an
        // inclusive bound on the year partition value.
        let actual = projected_string(schema, spec, unbound);

        assert_eq!(actual, "year <= 53");
    }

    #[test]
    fn test_inclusive_projection_date_month_transform() {
        let schema = Arc::new(build_test_schema());

        let month_of_date = UnboundPartitionField {
            source_id: 2,
            name: "month".to_string(),
            field_id: Some(1000),
            transform: Transform::Month,
        };
        let spec = PartitionSpec::builder(schema.clone())
            .with_spec_id(1)
            .add_unbound_fields(vec![month_of_date])
            .unwrap()
            .build()
            .unwrap();

        let unbound = Reference::new("date").less_than(Datum::date_from_str("2024-01-01").unwrap());

        // A strict upper bound on the date is expected to become an
        // inclusive bound on the month partition value.
        let actual = projected_string(schema, spec, unbound);

        assert_eq!(actual, "month <= 647");
    }

    #[test]
    fn test_inclusive_projection_date_day_transform() {
        let schema = Arc::new(build_test_schema());

        let day_of_date = UnboundPartitionField {
            source_id: 2,
            name: "day".to_string(),
            field_id: Some(1000),
            transform: Transform::Day,
        };
        let spec = PartitionSpec::builder(schema.clone())
            .with_spec_id(1)
            .add_unbound_fields(vec![day_of_date])
            .unwrap()
            .build()
            .unwrap();

        let unbound = Reference::new("date").less_than(Datum::date_from_str("2024-01-01").unwrap());

        // A strict upper bound on the date is expected to become an
        // inclusive bound on the previous day.
        let actual = projected_string(schema, spec, unbound);

        assert_eq!(actual, "day <= 2023-12-31");
    }

    #[test]
    fn test_inclusive_projection_truncate_transform() {
        let schema = Arc::new(build_test_schema());

        let truncated_name = UnboundPartitionField::builder()
            .source_id(3)
            .name("name_truncate".to_string())
            .field_id(3)
            .transform(Transform::Truncate(4))
            .build();
        let spec = PartitionSpec::builder(schema.clone())
            .with_spec_id(1)
            .add_unbound_field(truncated_name)
            .unwrap()
            .build()
            .unwrap();

        let unbound = Reference::new("name").starts_with(Datum::string("Testy McTest"));

        // The prefix is expected to be cut down to the truncate width (4).
        let actual = projected_string(schema, spec, unbound);

        assert_eq!(actual, "name_truncate STARTS WITH \"Test\"")
    }

    #[test]
    fn test_inclusive_projection_bucket_transform() {
        let schema = Arc::new(build_test_schema());

        let bucketed_a = UnboundPartitionField::builder()
            .source_id(1)
            .name("a_bucket[7]".to_string())
            .field_id(1)
            .transform(Transform::Bucket(7))
            .build();
        let spec = PartitionSpec::builder(schema.clone())
            .with_spec_id(1)
            .add_unbound_field(bucketed_a)
            .unwrap()
            .build()
            .unwrap();

        let unbound = Reference::new("a").equal_to(Datum::int(10));

        // Equality on the source value is expected to become equality on
        // the bucket the value falls into.
        let actual = projected_string(schema, spec, unbound);

        assert_eq!(actual, "a_bucket[7] = 2")
    }
}