iceberg/expr/
predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This module contains predicate expressions.
//! Predicate expressions are used to filter data and evaluate to a boolean value. For example,
//! `a > 10` is a predicate expression that evaluates to `true` if `a` is greater than `10`.
21
22use std::fmt::{Debug, Display, Formatter};
23use std::ops::Not;
24
25use array_init::array_init;
26use fnv::FnvHashSet;
27use itertools::Itertools;
28use serde::{Deserialize, Serialize};
29
30use crate::error::Result;
31use crate::expr::visitors::bound_predicate_visitor::visit as visit_bound;
32use crate::expr::visitors::predicate_visitor::visit;
33use crate::expr::visitors::rewrite_not::RewriteNotVisitor;
34use crate::expr::{Bind, BoundReference, PredicateOperator, Reference};
35use crate::spec::{Datum, PrimitiveLiteral, SchemaRef};
36use crate::{Error, ErrorKind};
37
/// Logical expression, such as `AND`, `OR`, `NOT`.
///
/// `N` is the number of operands. [`Predicate`] and [`BoundPredicate`] use `N = 2` for
/// `AND` / `OR` and `N = 1` for `NOT`. Operands are boxed because the predicate enums
/// contain this type recursively.
#[derive(PartialEq, Clone)]
pub struct LogicalExpression<T, const N: usize> {
    /// Operands of the logical operator, in order (e.g. `[lhs, rhs]` for `AND`).
    inputs: [Box<T>; N],
}
43
44impl<T: Serialize, const N: usize> Serialize for LogicalExpression<T, N> {
45    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
46    where S: serde::Serializer {
47        self.inputs.serialize(serializer)
48    }
49}
50
51impl<'de, T: Deserialize<'de>, const N: usize> Deserialize<'de> for LogicalExpression<T, N> {
52    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
53    where D: serde::Deserializer<'de> {
54        let inputs = Vec::<Box<T>>::deserialize(deserializer)?;
55        Ok(LogicalExpression::new(
56            array_init::from_iter(inputs.into_iter()).ok_or_else(|| {
57                serde::de::Error::custom(format!("Failed to deserialize LogicalExpression: the len of inputs is not match with the len of LogicalExpression {N}"))
58            })?,
59        ))
60    }
61}
62
63impl<T: Debug, const N: usize> Debug for LogicalExpression<T, N> {
64    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("LogicalExpression")
66            .field("inputs", &self.inputs)
67            .finish()
68    }
69}
70
71impl<T, const N: usize> LogicalExpression<T, N> {
72    fn new(inputs: [Box<T>; N]) -> Self {
73        Self { inputs }
74    }
75
76    /// Return inputs of this logical expression.
77    pub fn inputs(&self) -> [&T; N] {
78        let mut ret: [&T; N] = [self.inputs[0].as_ref(); N];
79        for (i, item) in ret.iter_mut().enumerate() {
80            *item = &self.inputs[i];
81        }
82        ret
83    }
84}
85
86impl<T: Bind, const N: usize> Bind for LogicalExpression<T, N>
87where T::Bound: Sized
88{
89    type Bound = LogicalExpression<T::Bound, N>;
90
91    fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
92        let mut outputs: [Option<Box<T::Bound>>; N] = array_init(|_| None);
93        for (i, input) in self.inputs.iter().enumerate() {
94            outputs[i] = Some(Box::new(input.bind(schema.clone(), case_sensitive)?));
95        }
96
97        // It's safe to use `unwrap` here since they are all `Some`.
98        let bound_inputs = array_init::from_iter(outputs.into_iter().map(Option::unwrap)).unwrap();
99        Ok(LogicalExpression::new(bound_inputs))
100    }
101}
102
/// Unary predicate, for example, `a IS NULL`.
///
/// `T` is the term type: [`Reference`] before binding and [`BoundReference`] after.
#[derive(PartialEq, Clone, Serialize, Deserialize)]
pub struct UnaryExpression<T> {
    /// Operator of this predicate, must be single operand operator.
    op: PredicateOperator,
    /// Term of this predicate, for example, `a` in `a IS NULL`.
    #[serde(bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>"))]
    term: T,
}
112
113impl<T: Debug> Debug for UnaryExpression<T> {
114    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("UnaryExpression")
116            .field("op", &self.op)
117            .field("term", &self.term)
118            .finish()
119    }
120}
121
122impl<T: Display> Display for UnaryExpression<T> {
123    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
124        write!(f, "{} {}", self.term, self.op)
125    }
126}
127
128impl<T: Bind> Bind for UnaryExpression<T> {
129    type Bound = UnaryExpression<T::Bound>;
130
131    fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
132        let bound_term = self.term.bind(schema, case_sensitive)?;
133        Ok(UnaryExpression::new(self.op, bound_term))
134    }
135}
136
137impl<T> UnaryExpression<T> {
138    /// Creates a unary expression with the given operator and term.
139    ///
140    /// # Example
141    ///
142    /// ```rust
143    /// use iceberg::expr::{PredicateOperator, Reference, UnaryExpression};
144    ///
145    /// UnaryExpression::new(PredicateOperator::IsNull, Reference::new("c"));
146    /// ```
147    pub fn new(op: PredicateOperator, term: T) -> Self {
148        debug_assert!(op.is_unary());
149        Self { op, term }
150    }
151
152    /// Return the operator of this predicate.
153    pub fn op(&self) -> PredicateOperator {
154        self.op
155    }
156
157    /// Return the term of this predicate.
158    pub fn term(&self) -> &T {
159        &self.term
160    }
161}
162
/// Binary predicate, for example, `a > 10`.
///
/// `T` is the term type: [`Reference`] before binding and [`BoundReference`] after.
#[derive(PartialEq, Clone, Serialize, Deserialize)]
pub struct BinaryExpression<T> {
    /// Operator of this predicate, must be binary operator, such as `=`, `>`, `<`, etc.
    op: PredicateOperator,
    /// Term of this predicate, for example, `a` in `a > 10`.
    #[serde(bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>"))]
    term: T,
    /// Literal of this predicate, for example, `10` in `a > 10`.
    literal: Datum,
}
174
175impl<T: Debug> Debug for BinaryExpression<T> {
176    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
177        f.debug_struct("BinaryExpression")
178            .field("op", &self.op)
179            .field("term", &self.term)
180            .field("literal", &self.literal)
181            .finish()
182    }
183}
184
185impl<T> BinaryExpression<T> {
186    /// Creates a binary expression with the given operator, term and literal.
187    ///
188    /// # Example
189    ///
190    /// ```rust
191    /// use iceberg::expr::{BinaryExpression, PredicateOperator, Reference};
192    /// use iceberg::spec::Datum;
193    ///
194    /// BinaryExpression::new(
195    ///     PredicateOperator::LessThanOrEq,
196    ///     Reference::new("a"),
197    ///     Datum::int(10),
198    /// );
199    /// ```
200    pub fn new(op: PredicateOperator, term: T, literal: Datum) -> Self {
201        debug_assert!(op.is_binary());
202        Self { op, term, literal }
203    }
204
205    /// Return the operator used by this predicate expression.
206    pub fn op(&self) -> PredicateOperator {
207        self.op
208    }
209
210    /// Return the literal of this predicate.
211    pub fn literal(&self) -> &Datum {
212        &self.literal
213    }
214
215    /// Return the term of this predicate.
216    pub fn term(&self) -> &T {
217        &self.term
218    }
219}
220
221impl<T: Display> Display for BinaryExpression<T> {
222    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
223        write!(f, "{} {} {}", self.term, self.op, self.literal)
224    }
225}
226
227impl<T: Bind> Bind for BinaryExpression<T> {
228    type Bound = BinaryExpression<T::Bound>;
229
230    fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
231        let bound_term = self.term.bind(schema.clone(), case_sensitive)?;
232        Ok(BinaryExpression::new(
233            self.op,
234            bound_term,
235            self.literal.clone(),
236        ))
237    }
238}
239
/// Set predicates, for example, `a in (1, 2, 3)`.
///
/// `T` is the term type: [`Reference`] before binding and [`BoundReference`] after.
#[derive(PartialEq, Clone, Serialize, Deserialize)]
pub struct SetExpression<T> {
    /// Operator of this predicate, must be set operator, such as `IN`, `NOT IN`, etc.
    op: PredicateOperator,
    /// Term of this predicate, for example, `a` in `a in (1, 2, 3)`.
    term: T,
    /// Literals of this predicate, for example, `(1, 2, 3)` in `a in (1, 2, 3)`.
    /// Stored in a hash set, so iteration (and therefore `Display` output) order is
    /// not the insertion order.
    literals: FnvHashSet<Datum>,
}
250
251impl<T: Debug> Debug for SetExpression<T> {
252    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("SetExpression")
254            .field("op", &self.op)
255            .field("term", &self.term)
256            .field("literal", &self.literals)
257            .finish()
258    }
259}
260
261impl<T> SetExpression<T> {
262    /// Creates a set expression with the given operator, term and literal.
263    ///
264    /// # Example
265    ///
266    /// ```rust
267    /// use fnv::FnvHashSet;
268    /// use iceberg::expr::{PredicateOperator, Reference, SetExpression};
269    /// use iceberg::spec::Datum;
270    ///
271    /// SetExpression::new(
272    ///     PredicateOperator::In,
273    ///     Reference::new("a"),
274    ///     FnvHashSet::from_iter(vec![Datum::int(1)]),
275    /// );
276    /// ```
277    pub fn new(op: PredicateOperator, term: T, literals: FnvHashSet<Datum>) -> Self {
278        debug_assert!(op.is_set());
279        Self { op, term, literals }
280    }
281
282    /// Return the operator of this predicate.
283    pub fn op(&self) -> PredicateOperator {
284        self.op
285    }
286
287    /// Return the hash set of values compared against the term in this expression.
288    pub fn literals(&self) -> &FnvHashSet<Datum> {
289        &self.literals
290    }
291
292    /// Return the term of this predicate.
293    pub fn term(&self) -> &T {
294        &self.term
295    }
296}
297
298impl<T: Bind> Bind for SetExpression<T> {
299    type Bound = SetExpression<T::Bound>;
300
301    fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
302        let bound_term = self.term.bind(schema.clone(), case_sensitive)?;
303        Ok(SetExpression::new(
304            self.op,
305            bound_term,
306            self.literals.clone(),
307        ))
308    }
309}
310
311impl<T: Display + Debug> Display for SetExpression<T> {
312    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
313        let mut literal_strs = self.literals.iter().map(|l| format!("{l}"));
314
315        write!(f, "{} {} ({})", self.term, self.op, literal_strs.join(", "))
316    }
317}
318
/// Unbound predicate expression before binding to a schema.
///
/// Terms are unresolved [`Reference`]s; call [`Bind::bind`] to resolve them against a
/// schema and obtain a [`BoundPredicate`].
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub enum Predicate {
    /// AlwaysTrue predicate, for example, `TRUE`.
    AlwaysTrue,
    /// AlwaysFalse predicate, for example, `FALSE`.
    AlwaysFalse,
    /// And predicate, for example, `a > 10 AND b < 20`.
    And(LogicalExpression<Predicate, 2>),
    /// Or predicate, for example, `a > 10 OR b < 20`.
    Or(LogicalExpression<Predicate, 2>),
    /// Not predicate, for example, `NOT (a > 10)`.
    Not(LogicalExpression<Predicate, 1>),
    /// Unary expression, for example, `a IS NULL`.
    Unary(UnaryExpression<Reference>),
    /// Binary expression, for example, `a > 10`.
    Binary(BinaryExpression<Reference>),
    /// Set predicates, for example, `a in (1, 2, 3)`.
    Set(SetExpression<Reference>),
}
339
340impl Bind for Predicate {
341    type Bound = BoundPredicate;
342
343    fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> Result<BoundPredicate> {
344        match self {
345            Predicate::And(expr) => {
346                let bound_expr = expr.bind(schema, case_sensitive)?;
347
348                let [left, right] = bound_expr.inputs;
349                Ok(match (left, right) {
350                    (_, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => {
351                        BoundPredicate::AlwaysFalse
352                    }
353                    (l, _) if matches!(&*l, &BoundPredicate::AlwaysFalse) => {
354                        BoundPredicate::AlwaysFalse
355                    }
356                    (left, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => *left,
357                    (l, right) if matches!(&*l, &BoundPredicate::AlwaysTrue) => *right,
358                    (left, right) => BoundPredicate::And(LogicalExpression::new([left, right])),
359                })
360            }
361            Predicate::Not(expr) => {
362                let bound_expr = expr.bind(schema, case_sensitive)?;
363                let [inner] = bound_expr.inputs;
364                Ok(match inner {
365                    e if matches!(&*e, &BoundPredicate::AlwaysTrue) => BoundPredicate::AlwaysFalse,
366                    e if matches!(&*e, &BoundPredicate::AlwaysFalse) => BoundPredicate::AlwaysTrue,
367                    e => BoundPredicate::Not(LogicalExpression::new([e])),
368                })
369            }
370            Predicate::Or(expr) => {
371                let bound_expr = expr.bind(schema, case_sensitive)?;
372                let [left, right] = bound_expr.inputs;
373                Ok(match (left, right) {
374                    (l, r)
375                        if matches!(&*r, &BoundPredicate::AlwaysTrue)
376                            || matches!(&*l, &BoundPredicate::AlwaysTrue) =>
377                    {
378                        BoundPredicate::AlwaysTrue
379                    }
380                    (left, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => *left,
381                    (l, right) if matches!(&*l, &BoundPredicate::AlwaysFalse) => *right,
382                    (left, right) => BoundPredicate::Or(LogicalExpression::new([left, right])),
383                })
384            }
385            Predicate::Unary(expr) => {
386                let bound_expr = expr.bind(schema, case_sensitive)?;
387
388                match &bound_expr.op {
389                    &PredicateOperator::IsNull => {
390                        if bound_expr.term.field().required {
391                            return Ok(BoundPredicate::AlwaysFalse);
392                        }
393                    }
394                    &PredicateOperator::NotNull => {
395                        if bound_expr.term.field().required {
396                            return Ok(BoundPredicate::AlwaysTrue);
397                        }
398                    }
399                    &PredicateOperator::IsNan | &PredicateOperator::NotNan => {
400                        if !bound_expr.term.field().field_type.is_floating_type() {
401                            return Err(Error::new(
402                                ErrorKind::DataInvalid,
403                                format!(
404                                    "Expecting floating point type, but found {}",
405                                    bound_expr.term.field().field_type
406                                ),
407                            ));
408                        }
409                    }
410                    op => {
411                        return Err(Error::new(
412                            ErrorKind::Unexpected,
413                            format!("Expecting unary operator, but found {op}"),
414                        ));
415                    }
416                }
417
418                Ok(BoundPredicate::Unary(bound_expr))
419            }
420            Predicate::Binary(expr) => {
421                let bound_expr = expr.bind(schema, case_sensitive)?;
422                let bound_literal = bound_expr.literal.to(&bound_expr.term.field().field_type)?;
423
424                match bound_literal.literal() {
425                    PrimitiveLiteral::AboveMax => match &bound_expr.op {
426                        &PredicateOperator::LessThan
427                        | &PredicateOperator::LessThanOrEq
428                        | &PredicateOperator::NotEq => {
429                            return Ok(BoundPredicate::AlwaysTrue);
430                        }
431                        &PredicateOperator::GreaterThan
432                        | &PredicateOperator::GreaterThanOrEq
433                        | &PredicateOperator::Eq => {
434                            return Ok(BoundPredicate::AlwaysFalse);
435                        }
436                        _ => {}
437                    },
438                    PrimitiveLiteral::BelowMin => match &bound_expr.op {
439                        &PredicateOperator::GreaterThan
440                        | &PredicateOperator::GreaterThanOrEq
441                        | &PredicateOperator::NotEq => {
442                            return Ok(BoundPredicate::AlwaysTrue);
443                        }
444                        &PredicateOperator::LessThan
445                        | &PredicateOperator::LessThanOrEq
446                        | &PredicateOperator::Eq => {
447                            return Ok(BoundPredicate::AlwaysFalse);
448                        }
449                        _ => {}
450                    },
451                    _ => {}
452                }
453
454                Ok(BoundPredicate::Binary(BinaryExpression::new(
455                    bound_expr.op,
456                    bound_expr.term,
457                    bound_literal,
458                )))
459            }
460            Predicate::Set(expr) => {
461                let bound_expr = expr.bind(schema, case_sensitive)?;
462                let bound_literals = bound_expr
463                    .literals
464                    .into_iter()
465                    .map(|l| l.to(&bound_expr.term.field().field_type))
466                    .collect::<Result<FnvHashSet<Datum>>>()?;
467
468                match &bound_expr.op {
469                    &PredicateOperator::In => {
470                        if bound_literals.is_empty() {
471                            return Ok(BoundPredicate::AlwaysFalse);
472                        }
473                        if bound_literals.len() == 1 {
474                            return Ok(BoundPredicate::Binary(BinaryExpression::new(
475                                PredicateOperator::Eq,
476                                bound_expr.term,
477                                bound_literals.into_iter().next().unwrap(),
478                            )));
479                        }
480                    }
481                    &PredicateOperator::NotIn => {
482                        if bound_literals.is_empty() {
483                            return Ok(BoundPredicate::AlwaysTrue);
484                        }
485                        if bound_literals.len() == 1 {
486                            return Ok(BoundPredicate::Binary(BinaryExpression::new(
487                                PredicateOperator::NotEq,
488                                bound_expr.term,
489                                bound_literals.into_iter().next().unwrap(),
490                            )));
491                        }
492                    }
493                    op => {
494                        return Err(Error::new(
495                            ErrorKind::Unexpected,
496                            format!("Expecting unary operator,but found {op}"),
497                        ));
498                    }
499                }
500
501                Ok(BoundPredicate::Set(SetExpression::new(
502                    bound_expr.op,
503                    bound_expr.term,
504                    bound_literals,
505                )))
506            }
507            Predicate::AlwaysTrue => Ok(BoundPredicate::AlwaysTrue),
508            Predicate::AlwaysFalse => Ok(BoundPredicate::AlwaysFalse),
509        }
510    }
511}
512
513impl Display for Predicate {
514    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
515        match self {
516            Predicate::AlwaysTrue => {
517                write!(f, "TRUE")
518            }
519            Predicate::AlwaysFalse => {
520                write!(f, "FALSE")
521            }
522            Predicate::And(expr) => {
523                write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1])
524            }
525            Predicate::Or(expr) => {
526                write!(f, "({}) OR ({})", expr.inputs()[0], expr.inputs()[1])
527            }
528            Predicate::Not(expr) => {
529                write!(f, "NOT ({})", expr.inputs()[0])
530            }
531            Predicate::Unary(expr) => {
532                write!(f, "{expr}")
533            }
534            Predicate::Binary(expr) => {
535                write!(f, "{expr}")
536            }
537            Predicate::Set(expr) => {
538                write!(f, "{expr}")
539            }
540        }
541    }
542}
543
544impl Predicate {
545    /// Combines two predicates with `AND`.
546    ///
547    /// # Example
548    ///
549    /// ```rust
550    /// use std::ops::Bound::Unbounded;
551    ///
552    /// use iceberg::expr::BoundPredicate::Unary;
553    /// use iceberg::expr::Reference;
554    /// use iceberg::spec::Datum;
555    /// let expr1 = Reference::new("a").less_than(Datum::long(10));
556    ///
557    /// let expr2 = Reference::new("b").less_than(Datum::long(20));
558    ///
559    /// let expr = expr1.and(expr2);
560    ///
561    /// assert_eq!(&format!("{expr}"), "(a < 10) AND (b < 20)");
562    /// ```
563    pub fn and(self, other: Predicate) -> Predicate {
564        match (self, other) {
565            (Predicate::AlwaysFalse, _) => Predicate::AlwaysFalse,
566            (_, Predicate::AlwaysFalse) => Predicate::AlwaysFalse,
567            (Predicate::AlwaysTrue, rhs) => rhs,
568            (lhs, Predicate::AlwaysTrue) => lhs,
569            (lhs, rhs) => Predicate::And(LogicalExpression::new([Box::new(lhs), Box::new(rhs)])),
570        }
571    }
572
573    /// Combines two predicates with `OR`.
574    ///
575    /// # Example
576    ///
577    /// ```rust
578    /// use std::ops::Bound::Unbounded;
579    ///
580    /// use iceberg::expr::BoundPredicate::Unary;
581    /// use iceberg::expr::Reference;
582    /// use iceberg::spec::Datum;
583    /// let expr1 = Reference::new("a").less_than(Datum::long(10));
584    ///
585    /// let expr2 = Reference::new("b").less_than(Datum::long(20));
586    ///
587    /// let expr = expr1.or(expr2);
588    ///
589    /// assert_eq!(&format!("{expr}"), "(a < 10) OR (b < 20)");
590    /// ```
591    pub fn or(self, other: Predicate) -> Predicate {
592        match (self, other) {
593            (Predicate::AlwaysTrue, _) => Predicate::AlwaysTrue,
594            (_, Predicate::AlwaysTrue) => Predicate::AlwaysTrue,
595            (Predicate::AlwaysFalse, rhs) => rhs,
596            (lhs, Predicate::AlwaysFalse) => lhs,
597            (lhs, rhs) => Predicate::Or(LogicalExpression::new([Box::new(lhs), Box::new(rhs)])),
598        }
599    }
600
601    /// Returns a predicate representing the negation ('NOT') of this one,
602    /// by using inverse predicates rather than wrapping in a `NOT`.
603    /// Used for `NOT` elimination.
604    ///
605    /// # Example
606    ///
607    /// ```rust
608    /// use std::ops::Bound::Unbounded;
609    ///
610    /// use iceberg::expr::BoundPredicate::Unary;
611    /// use iceberg::expr::{LogicalExpression, Predicate, Reference};
612    /// use iceberg::spec::Datum;
613    /// let expr1 = Reference::new("a").less_than(Datum::long(10));
614    /// let expr2 = Reference::new("b")
615    ///     .less_than(Datum::long(5))
616    ///     .and(Reference::new("c").less_than(Datum::long(10)));
617    ///
618    /// let result = expr1.negate();
619    /// assert_eq!(&format!("{result}"), "a >= 10");
620    ///
621    /// let result = expr2.negate();
622    /// assert_eq!(&format!("{result}"), "(b >= 5) OR (c >= 10)");
623    /// ```
624    pub fn negate(self) -> Predicate {
625        match self {
626            Predicate::AlwaysTrue => Predicate::AlwaysFalse,
627            Predicate::AlwaysFalse => Predicate::AlwaysTrue,
628            Predicate::And(expr) => Predicate::Or(LogicalExpression::new(
629                expr.inputs.map(|expr| Box::new(expr.negate())),
630            )),
631            Predicate::Or(expr) => Predicate::And(LogicalExpression::new(
632                expr.inputs.map(|expr| Box::new(expr.negate())),
633            )),
634            Predicate::Not(expr) => {
635                let LogicalExpression { inputs: [input_0] } = expr;
636                *input_0
637            }
638            Predicate::Unary(expr) => {
639                Predicate::Unary(UnaryExpression::new(expr.op.negate(), expr.term))
640            }
641            Predicate::Binary(expr) => Predicate::Binary(BinaryExpression::new(
642                expr.op.negate(),
643                expr.term,
644                expr.literal,
645            )),
646            Predicate::Set(expr) => Predicate::Set(SetExpression::new(
647                expr.op.negate(),
648                expr.term,
649                expr.literals,
650            )),
651        }
652    }
653    /// Simplifies the expression by removing `NOT` predicates,
654    /// directly negating the inner expressions instead. The transformation
655    /// applies logical laws (such as De Morgan's laws) to
656    /// recursively negate and simplify inner expressions within `NOT`
657    /// predicates.
658    ///
659    /// # Example
660    ///
661    /// ```rust
662    /// use std::ops::Not;
663    ///
664    /// use iceberg::expr::{LogicalExpression, Predicate, Reference};
665    /// use iceberg::spec::Datum;
666    ///
667    /// let expression = Reference::new("a").less_than(Datum::long(5)).not();
668    /// let result = expression.rewrite_not();
669    ///
670    /// assert_eq!(&format!("{result}"), "a >= 5");
671    /// ```
672    pub fn rewrite_not(self) -> Predicate {
673        visit(&mut RewriteNotVisitor::new(), &self)
674            .expect("RewriteNotVisitor guarantees always success")
675    }
676}
677
678impl Not for Predicate {
679    type Output = Predicate;
680
681    /// Create a predicate which is the reverse of this predicate. For example: `NOT (a > 10)`.
682    ///
683    /// This is different from [`Predicate::negate()`] since it doesn't rewrite expression, but
684    /// just adds a `NOT` operator.
685    ///
686    /// # Example
687    ///     
688    ///```rust
689    /// use std::ops::Bound::Unbounded;
690    ///
691    /// use iceberg::expr::BoundPredicate::Unary;
692    /// use iceberg::expr::Reference;
693    /// use iceberg::spec::Datum;
694    /// let expr1 = Reference::new("a").less_than(Datum::long(10));
695    ///
696    /// let expr = !expr1;
697    ///
698    /// assert_eq!(&format!("{expr}"), "NOT (a < 10)");
699    /// ```
700    fn not(self) -> Self::Output {
701        Predicate::Not(LogicalExpression::new([Box::new(self)]))
702    }
703}
704
/// Bound predicate expression after binding to a schema.
///
/// Produced by binding a [`Predicate`]; every term is a [`BoundReference`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BoundPredicate {
    /// An expression always evaluates to true.
    AlwaysTrue,
    /// An expression always evaluates to false.
    AlwaysFalse,
    /// An expression combined by `AND`, for example, `a > 10 AND b < 20`.
    And(LogicalExpression<BoundPredicate, 2>),
    /// An expression combined by `OR`, for example, `a > 10 OR b < 20`.
    Or(LogicalExpression<BoundPredicate, 2>),
    /// An expression combined by `NOT`, for example, `NOT (a > 10)`.
    Not(LogicalExpression<BoundPredicate, 1>),
    /// Unary expression, for example, `a IS NULL`.
    Unary(UnaryExpression<BoundReference>),
    /// Binary expression, for example, `a > 10`.
    Binary(BinaryExpression<BoundReference>),
    /// Set predicates, for example, `a IN (1, 2, 3)`.
    Set(SetExpression<BoundReference>),
}
725
726impl BoundPredicate {
727    pub(crate) fn and(self, other: BoundPredicate) -> BoundPredicate {
728        BoundPredicate::And(LogicalExpression::new([Box::new(self), Box::new(other)]))
729    }
730
731    pub(crate) fn or(self, other: BoundPredicate) -> BoundPredicate {
732        BoundPredicate::Or(LogicalExpression::new([Box::new(self), Box::new(other)]))
733    }
734
735    pub(crate) fn negate(self) -> BoundPredicate {
736        match self {
737            BoundPredicate::AlwaysTrue => BoundPredicate::AlwaysFalse,
738            BoundPredicate::AlwaysFalse => BoundPredicate::AlwaysTrue,
739            BoundPredicate::And(expr) => BoundPredicate::Or(LogicalExpression::new(
740                expr.inputs.map(|expr| Box::new(expr.negate())),
741            )),
742            BoundPredicate::Or(expr) => BoundPredicate::And(LogicalExpression::new(
743                expr.inputs.map(|expr| Box::new(expr.negate())),
744            )),
745            BoundPredicate::Not(expr) => {
746                let LogicalExpression { inputs: [input_0] } = expr;
747                *input_0
748            }
749            BoundPredicate::Unary(expr) => {
750                BoundPredicate::Unary(UnaryExpression::new(expr.op.negate(), expr.term))
751            }
752            BoundPredicate::Binary(expr) => BoundPredicate::Binary(BinaryExpression::new(
753                expr.op.negate(),
754                expr.term,
755                expr.literal,
756            )),
757            BoundPredicate::Set(expr) => BoundPredicate::Set(SetExpression::new(
758                expr.op.negate(),
759                expr.term,
760                expr.literals,
761            )),
762        }
763    }
764
765    /// Simplifies the expression by removing `NOT` predicates,
766    /// directly negating the inner expressions instead. The transformation
767    /// applies logical laws (such as De Morgan's laws) to
768    /// recursively negate and simplify inner expressions within `NOT`
769    /// predicates.
770    ///
771    /// # Example
772    ///
773    /// ```rust
774    /// use std::ops::Not;
775    ///
776    /// use iceberg::expr::{Bind, BoundPredicate, Reference};
777    /// use iceberg::spec::Datum;
778    ///
779    /// // This would need to be bound first, but the concept is:
780    /// // let expression = bound_predicate.not();
781    /// // let result = expression.rewrite_not();
782    /// ```
783    pub fn rewrite_not(self) -> BoundPredicate {
784        visit_bound(&mut RewriteNotVisitor::new(), &self)
785            .expect("RewriteNotVisitor guarantees always success")
786    }
787}
788
789impl Display for BoundPredicate {
790    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
791        match self {
792            BoundPredicate::AlwaysTrue => {
793                write!(f, "True")
794            }
795            BoundPredicate::AlwaysFalse => {
796                write!(f, "False")
797            }
798            BoundPredicate::And(expr) => {
799                write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1])
800            }
801            BoundPredicate::Or(expr) => {
802                write!(f, "({}) OR ({})", expr.inputs()[0], expr.inputs()[1])
803            }
804            BoundPredicate::Not(expr) => {
805                write!(f, "NOT ({})", expr.inputs()[0])
806            }
807            BoundPredicate::Unary(expr) => {
808                write!(f, "{expr}")
809            }
810            BoundPredicate::Binary(expr) => {
811                write!(f, "{expr}")
812            }
813            BoundPredicate::Set(expr) => {
814                write!(f, "{expr}")
815            }
816        }
817    }
818}
819
820#[cfg(test)]
821mod tests {
822    use std::ops::Not;
823    use std::sync::Arc;
824
825    use crate::expr::Predicate::{AlwaysFalse, AlwaysTrue};
826    use crate::expr::{Bind, BoundPredicate, Reference};
827    use crate::spec::{Datum, NestedField, PrimitiveType, Schema, SchemaRef, Type};
828
829    #[test]
830    fn test_logical_or_rewrite_not() {
831        let expression = Reference::new("b")
832            .less_than(Datum::long(5))
833            .or(Reference::new("c").less_than(Datum::long(10)))
834            .not();
835
836        let expected = Reference::new("b")
837            .greater_than_or_equal_to(Datum::long(5))
838            .and(Reference::new("c").greater_than_or_equal_to(Datum::long(10)));
839
840        let result = expression.rewrite_not();
841
842        assert_eq!(result, expected);
843    }
844
845    #[test]
846    fn test_logical_and_rewrite_not() {
847        let expression = Reference::new("b")
848            .less_than(Datum::long(5))
849            .and(Reference::new("c").less_than(Datum::long(10)))
850            .not();
851
852        let expected = Reference::new("b")
853            .greater_than_or_equal_to(Datum::long(5))
854            .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10)));
855
856        let result = expression.rewrite_not();
857
858        assert_eq!(result, expected);
859    }
860
861    #[test]
862    fn test_set_rewrite_not() {
863        let expression = Reference::new("a")
864            .is_in([Datum::int(5), Datum::int(6)])
865            .not();
866
867        let expected = Reference::new("a").is_not_in([Datum::int(5), Datum::int(6)]);
868
869        let result = expression.rewrite_not();
870
871        assert_eq!(result, expected);
872    }
873
874    #[test]
875    fn test_binary_rewrite_not() {
876        let expression = Reference::new("a").less_than(Datum::long(5)).not();
877
878        let expected = Reference::new("a").greater_than_or_equal_to(Datum::long(5));
879
880        let result = expression.rewrite_not();
881
882        assert_eq!(result, expected);
883    }
884
885    #[test]
886    fn test_unary_rewrite_not() {
887        let expression = Reference::new("a").is_null().not();
888
889        let expected = Reference::new("a").is_not_null();
890
891        let result = expression.rewrite_not();
892
893        assert_eq!(result, expected);
894    }
895
896    #[test]
897    fn test_predicate_and_reduce_always_true_false() {
898        let true_or_expr = AlwaysTrue.and(Reference::new("b").less_than(Datum::long(5)));
899        assert_eq!(&format!("{true_or_expr}"), "b < 5");
900
901        let expr_or_true = Reference::new("b")
902            .less_than(Datum::long(5))
903            .and(AlwaysTrue);
904        assert_eq!(&format!("{expr_or_true}"), "b < 5");
905
906        let false_or_expr = AlwaysFalse.and(Reference::new("b").less_than(Datum::long(5)));
907        assert_eq!(&format!("{false_or_expr}"), "FALSE");
908
909        let expr_or_false = Reference::new("b")
910            .less_than(Datum::long(5))
911            .and(AlwaysFalse);
912        assert_eq!(&format!("{expr_or_false}"), "FALSE");
913    }
914
915    #[test]
916    fn test_predicate_or_reduce_always_true_false() {
917        let true_or_expr = AlwaysTrue.or(Reference::new("b").less_than(Datum::long(5)));
918        assert_eq!(&format!("{true_or_expr}"), "TRUE");
919
920        let expr_or_true = Reference::new("b").less_than(Datum::long(5)).or(AlwaysTrue);
921        assert_eq!(&format!("{expr_or_true}"), "TRUE");
922
923        let false_or_expr = AlwaysFalse.or(Reference::new("b").less_than(Datum::long(5)));
924        assert_eq!(&format!("{false_or_expr}"), "b < 5");
925
926        let expr_or_false = Reference::new("b")
927            .less_than(Datum::long(5))
928            .or(AlwaysFalse);
929        assert_eq!(&format!("{expr_or_false}"), "b < 5");
930    }
931
932    #[test]
933    fn test_predicate_negate_and() {
934        let expression = Reference::new("b")
935            .less_than(Datum::long(5))
936            .and(Reference::new("c").less_than(Datum::long(10)));
937
938        let expected = Reference::new("b")
939            .greater_than_or_equal_to(Datum::long(5))
940            .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10)));
941
942        let result = expression.negate();
943
944        assert_eq!(result, expected);
945    }
946
947    #[test]
948    fn test_predicate_negate_or() {
949        let expression = Reference::new("b")
950            .greater_than_or_equal_to(Datum::long(5))
951            .or(Reference::new("c").greater_than_or_equal_to(Datum::long(10)));
952
953        let expected = Reference::new("b")
954            .less_than(Datum::long(5))
955            .and(Reference::new("c").less_than(Datum::long(10)));
956
957        let result = expression.negate();
958
959        assert_eq!(result, expected);
960    }
961
962    #[test]
963    fn test_predicate_negate_not() {
964        let expression = Reference::new("b")
965            .greater_than_or_equal_to(Datum::long(5))
966            .not();
967
968        let expected = Reference::new("b").greater_than_or_equal_to(Datum::long(5));
969
970        let result = expression.negate();
971
972        assert_eq!(result, expected);
973    }
974
975    #[test]
976    fn test_predicate_negate_unary() {
977        let expression = Reference::new("b").is_not_null();
978
979        let expected = Reference::new("b").is_null();
980
981        let result = expression.negate();
982
983        assert_eq!(result, expected);
984    }
985
986    #[test]
987    fn test_predicate_negate_binary() {
988        let expression = Reference::new("a").less_than(Datum::long(5));
989
990        let expected = Reference::new("a").greater_than_or_equal_to(Datum::long(5));
991
992        let result = expression.negate();
993
994        assert_eq!(result, expected);
995    }
996
997    #[test]
998    fn test_predicate_negate_set() {
999        let expression = Reference::new("a").is_in([Datum::long(5), Datum::long(6)]);
1000
1001        let expected = Reference::new("a").is_not_in([Datum::long(5), Datum::long(6)]);
1002
1003        let result = expression.negate();
1004
1005        assert_eq!(result, expected);
1006    }
1007
1008    pub fn table_schema_simple() -> SchemaRef {
1009        Arc::new(
1010            Schema::builder()
1011                .with_schema_id(1)
1012                .with_identifier_field_ids(vec![2])
1013                .with_fields(vec![
1014                    NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(),
1015                    NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
1016                    NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(),
1017                    NestedField::optional(4, "qux", Type::Primitive(PrimitiveType::Float)).into(),
1018                ])
1019                .build()
1020                .unwrap(),
1021        )
1022    }
1023
1024    fn test_bound_predicate_serialize_diserialize(bound_predicate: BoundPredicate) {
1025        let serialized = serde_json::to_string(&bound_predicate).unwrap();
1026        let deserialized: BoundPredicate = serde_json::from_str(&serialized).unwrap();
1027        assert_eq!(bound_predicate, deserialized);
1028    }
1029
1030    #[test]
1031    fn test_bind_is_null() {
1032        let schema = table_schema_simple();
1033        let expr = Reference::new("foo").is_null();
1034        let bound_expr = expr.bind(schema, true).unwrap();
1035        assert_eq!(&format!("{bound_expr}"), "foo IS NULL");
1036        test_bound_predicate_serialize_diserialize(bound_expr);
1037    }
1038
1039    #[test]
1040    fn test_bind_is_null_required() {
1041        let schema = table_schema_simple();
1042        let expr = Reference::new("bar").is_null();
1043        let bound_expr = expr.bind(schema, true).unwrap();
1044        assert_eq!(&format!("{bound_expr}"), "False");
1045        test_bound_predicate_serialize_diserialize(bound_expr);
1046    }
1047
1048    #[test]
1049    fn test_bind_is_not_null() {
1050        let schema = table_schema_simple();
1051        let expr = Reference::new("foo").is_not_null();
1052        let bound_expr = expr.bind(schema, true).unwrap();
1053        assert_eq!(&format!("{bound_expr}"), "foo IS NOT NULL");
1054        test_bound_predicate_serialize_diserialize(bound_expr);
1055    }
1056
1057    #[test]
1058    fn test_bind_is_not_null_required() {
1059        let schema = table_schema_simple();
1060        let expr = Reference::new("bar").is_not_null();
1061        let bound_expr = expr.bind(schema, true).unwrap();
1062        assert_eq!(&format!("{bound_expr}"), "True");
1063        test_bound_predicate_serialize_diserialize(bound_expr);
1064    }
1065
1066    #[test]
1067    fn test_bind_is_nan() {
1068        let schema = table_schema_simple();
1069        let expr = Reference::new("qux").is_nan();
1070        let bound_expr = expr.bind(schema, true).unwrap();
1071        assert_eq!(&format!("{bound_expr}"), "qux IS NAN");
1072
1073        let schema_string = table_schema_simple();
1074        let expr_string = Reference::new("foo").is_nan();
1075        let bound_expr_string = expr_string.bind(schema_string, true);
1076        assert!(bound_expr_string.is_err());
1077        test_bound_predicate_serialize_diserialize(bound_expr);
1078    }
1079
1080    #[test]
1081    fn test_bind_is_nan_wrong_type() {
1082        let schema = table_schema_simple();
1083        let expr = Reference::new("foo").is_nan();
1084        let bound_expr = expr.bind(schema, true);
1085        assert!(bound_expr.is_err());
1086    }
1087
1088    #[test]
1089    fn test_bind_is_not_nan() {
1090        let schema = table_schema_simple();
1091        let expr = Reference::new("qux").is_not_nan();
1092        let bound_expr = expr.bind(schema, true).unwrap();
1093        assert_eq!(&format!("{bound_expr}"), "qux IS NOT NAN");
1094        test_bound_predicate_serialize_diserialize(bound_expr);
1095    }
1096
1097    #[test]
1098    fn test_bind_is_not_nan_wrong_type() {
1099        let schema = table_schema_simple();
1100        let expr = Reference::new("foo").is_not_nan();
1101        let bound_expr = expr.bind(schema, true);
1102        assert!(bound_expr.is_err());
1103    }
1104
1105    #[test]
1106    fn test_bind_less_than() {
1107        let schema = table_schema_simple();
1108        let expr = Reference::new("bar").less_than(Datum::int(10));
1109        let bound_expr = expr.bind(schema, true).unwrap();
1110        assert_eq!(&format!("{bound_expr}"), "bar < 10");
1111        test_bound_predicate_serialize_diserialize(bound_expr);
1112    }
1113
1114    #[test]
1115    fn test_bind_less_than_wrong_type() {
1116        let schema = table_schema_simple();
1117        let expr = Reference::new("bar").less_than(Datum::string("abcd"));
1118        let bound_expr = expr.bind(schema, true);
1119        assert!(bound_expr.is_err());
1120    }
1121
1122    #[test]
1123    fn test_bind_less_than_or_eq() {
1124        let schema = table_schema_simple();
1125        let expr = Reference::new("bar").less_than_or_equal_to(Datum::int(10));
1126        let bound_expr = expr.bind(schema, true).unwrap();
1127        assert_eq!(&format!("{bound_expr}"), "bar <= 10");
1128        test_bound_predicate_serialize_diserialize(bound_expr);
1129    }
1130
1131    #[test]
1132    fn test_bind_less_than_or_eq_wrong_type() {
1133        let schema = table_schema_simple();
1134        let expr = Reference::new("bar").less_than_or_equal_to(Datum::string("abcd"));
1135        let bound_expr = expr.bind(schema, true);
1136        assert!(bound_expr.is_err());
1137    }
1138
1139    #[test]
1140    fn test_bind_greater_than() {
1141        let schema = table_schema_simple();
1142        let expr = Reference::new("bar").greater_than(Datum::int(10));
1143        let bound_expr = expr.bind(schema, true).unwrap();
1144        assert_eq!(&format!("{bound_expr}"), "bar > 10");
1145        test_bound_predicate_serialize_diserialize(bound_expr);
1146    }
1147
1148    #[test]
1149    fn test_bind_greater_than_wrong_type() {
1150        let schema = table_schema_simple();
1151        let expr = Reference::new("bar").greater_than(Datum::string("abcd"));
1152        let bound_expr = expr.bind(schema, true);
1153        assert!(bound_expr.is_err());
1154    }
1155
1156    #[test]
1157    fn test_bind_greater_than_or_eq() {
1158        let schema = table_schema_simple();
1159        let expr = Reference::new("bar").greater_than_or_equal_to(Datum::int(10));
1160        let bound_expr = expr.bind(schema, true).unwrap();
1161        assert_eq!(&format!("{bound_expr}"), "bar >= 10");
1162        test_bound_predicate_serialize_diserialize(bound_expr);
1163    }
1164
1165    #[test]
1166    fn test_bind_greater_than_or_eq_wrong_type() {
1167        let schema = table_schema_simple();
1168        let expr = Reference::new("bar").greater_than_or_equal_to(Datum::string("abcd"));
1169        let bound_expr = expr.bind(schema, true);
1170        assert!(bound_expr.is_err());
1171    }
1172
1173    #[test]
1174    fn test_bind_equal_to() {
1175        let schema = table_schema_simple();
1176        let expr = Reference::new("bar").equal_to(Datum::int(10));
1177        let bound_expr = expr.bind(schema, true).unwrap();
1178        assert_eq!(&format!("{bound_expr}"), "bar = 10");
1179        test_bound_predicate_serialize_diserialize(bound_expr);
1180    }
1181
1182    #[test]
1183    fn test_bind_equal_to_above_max() {
1184        let schema = table_schema_simple();
1185        // int32 can hold up to 2147483647
1186        let expr = Reference::new("bar").equal_to(Datum::long(2147483648i64));
1187        let bound_expr = expr.bind(schema, true).unwrap();
1188        assert_eq!(&format!("{bound_expr}"), "False");
1189        test_bound_predicate_serialize_diserialize(bound_expr);
1190    }
1191
1192    #[test]
1193    fn test_bind_equal_to_below_min() {
1194        let schema = table_schema_simple();
1195        // int32 can hold up to -2147483647
1196        let expr = Reference::new("bar").equal_to(Datum::long(-2147483649i64));
1197        let bound_expr = expr.bind(schema, true).unwrap();
1198        assert_eq!(&format!("{bound_expr}"), "False");
1199        test_bound_predicate_serialize_diserialize(bound_expr);
1200    }
1201
1202    #[test]
1203    fn test_bind_not_equal_to_above_max() {
1204        let schema = table_schema_simple();
1205        // int32 can hold up to 2147483647
1206        let expr = Reference::new("bar").not_equal_to(Datum::long(2147483648i64));
1207        let bound_expr = expr.bind(schema, true).unwrap();
1208        assert_eq!(&format!("{bound_expr}"), "True");
1209        test_bound_predicate_serialize_diserialize(bound_expr);
1210    }
1211
1212    #[test]
1213    fn test_bind_not_equal_to_below_min() {
1214        let schema = table_schema_simple();
1215        // int32 can hold up to -2147483647
1216        let expr = Reference::new("bar").not_equal_to(Datum::long(-2147483649i64));
1217        let bound_expr = expr.bind(schema, true).unwrap();
1218        assert_eq!(&format!("{bound_expr}"), "True");
1219        test_bound_predicate_serialize_diserialize(bound_expr);
1220    }
1221
1222    #[test]
1223    fn test_bind_less_than_above_max() {
1224        let schema = table_schema_simple();
1225        // int32 can hold up to 2147483647
1226        let expr = Reference::new("bar").less_than(Datum::long(2147483648i64));
1227        let bound_expr = expr.bind(schema, true).unwrap();
1228        assert_eq!(&format!("{bound_expr}"), "True");
1229        test_bound_predicate_serialize_diserialize(bound_expr);
1230    }
1231
1232    #[test]
1233    fn test_bind_less_than_below_min() {
1234        let schema = table_schema_simple();
1235        // int32 can hold up to -2147483647
1236        let expr = Reference::new("bar").less_than(Datum::long(-2147483649i64));
1237        let bound_expr = expr.bind(schema, true).unwrap();
1238        assert_eq!(&format!("{bound_expr}"), "False");
1239        test_bound_predicate_serialize_diserialize(bound_expr);
1240    }
1241
1242    #[test]
1243    fn test_bind_less_than_or_equal_to_above_max() {
1244        let schema = table_schema_simple();
1245        // int32 can hold up to 2147483647
1246        let expr = Reference::new("bar").less_than_or_equal_to(Datum::long(2147483648i64));
1247        let bound_expr = expr.bind(schema, true).unwrap();
1248        assert_eq!(&format!("{bound_expr}"), "True");
1249        test_bound_predicate_serialize_diserialize(bound_expr);
1250    }
1251
1252    #[test]
1253    fn test_bind_less_than_or_equal_to_below_min() {
1254        let schema = table_schema_simple();
1255        // int32 can hold up to -2147483647
1256        let expr = Reference::new("bar").less_than_or_equal_to(Datum::long(-2147483649i64));
1257        let bound_expr = expr.bind(schema, true).unwrap();
1258        assert_eq!(&format!("{bound_expr}"), "False");
1259        test_bound_predicate_serialize_diserialize(bound_expr);
1260    }
1261
1262    #[test]
1263    fn test_bind_great_than_above_max() {
1264        let schema = table_schema_simple();
1265        // int32 can hold up to 2147483647
1266        let expr = Reference::new("bar").greater_than(Datum::long(2147483648i64));
1267        let bound_expr = expr.bind(schema, true).unwrap();
1268        assert_eq!(&format!("{bound_expr}"), "False");
1269        test_bound_predicate_serialize_diserialize(bound_expr);
1270    }
1271
1272    #[test]
1273    fn test_bind_great_than_below_min() {
1274        let schema = table_schema_simple();
1275        // int32 can hold up to -2147483647
1276        let expr = Reference::new("bar").greater_than(Datum::long(-2147483649i64));
1277        let bound_expr = expr.bind(schema, true).unwrap();
1278        assert_eq!(&format!("{bound_expr}"), "True");
1279        test_bound_predicate_serialize_diserialize(bound_expr);
1280    }
1281
1282    #[test]
1283    fn test_bind_great_than_or_equal_to_above_max() {
1284        let schema = table_schema_simple();
1285        // int32 can hold up to 2147483647
1286        let expr = Reference::new("bar").greater_than_or_equal_to(Datum::long(2147483648i64));
1287        let bound_expr = expr.bind(schema, true).unwrap();
1288        assert_eq!(&format!("{bound_expr}"), "False");
1289        test_bound_predicate_serialize_diserialize(bound_expr);
1290    }
1291
1292    #[test]
1293    fn test_bind_great_than_or_equal_to_below_min() {
1294        let schema = table_schema_simple();
1295        // int32 can hold up to -2147483647
1296        let expr = Reference::new("bar").greater_than_or_equal_to(Datum::long(-2147483649i64));
1297        let bound_expr = expr.bind(schema, true).unwrap();
1298        assert_eq!(&format!("{bound_expr}"), "True");
1299        test_bound_predicate_serialize_diserialize(bound_expr);
1300    }
1301
1302    #[test]
1303    fn test_bind_equal_to_wrong_type() {
1304        let schema = table_schema_simple();
1305        let expr = Reference::new("bar").equal_to(Datum::string("abcd"));
1306        let bound_expr = expr.bind(schema, true);
1307        assert!(bound_expr.is_err());
1308    }
1309
1310    #[test]
1311    fn test_bind_not_equal_to() {
1312        let schema = table_schema_simple();
1313        let expr = Reference::new("bar").not_equal_to(Datum::int(10));
1314        let bound_expr = expr.bind(schema, true).unwrap();
1315        assert_eq!(&format!("{bound_expr}"), "bar != 10");
1316        test_bound_predicate_serialize_diserialize(bound_expr);
1317    }
1318
1319    #[test]
1320    fn test_bind_not_equal_to_wrong_type() {
1321        let schema = table_schema_simple();
1322        let expr = Reference::new("bar").not_equal_to(Datum::string("abcd"));
1323        let bound_expr = expr.bind(schema, true);
1324        assert!(bound_expr.is_err());
1325    }
1326
1327    #[test]
1328    fn test_bind_starts_with() {
1329        let schema = table_schema_simple();
1330        let expr = Reference::new("foo").starts_with(Datum::string("abcd"));
1331        let bound_expr = expr.bind(schema, true).unwrap();
1332        assert_eq!(&format!("{bound_expr}"), r#"foo STARTS WITH "abcd""#);
1333        test_bound_predicate_serialize_diserialize(bound_expr);
1334    }
1335
1336    #[test]
1337    fn test_bind_starts_with_wrong_type() {
1338        let schema = table_schema_simple();
1339        let expr = Reference::new("bar").starts_with(Datum::string("abcd"));
1340        let bound_expr = expr.bind(schema, true);
1341        assert!(bound_expr.is_err());
1342    }
1343
1344    #[test]
1345    fn test_bind_not_starts_with() {
1346        let schema = table_schema_simple();
1347        let expr = Reference::new("foo").not_starts_with(Datum::string("abcd"));
1348        let bound_expr = expr.bind(schema, true).unwrap();
1349        assert_eq!(&format!("{bound_expr}"), r#"foo NOT STARTS WITH "abcd""#);
1350        test_bound_predicate_serialize_diserialize(bound_expr);
1351    }
1352
1353    #[test]
1354    fn test_bind_not_starts_with_wrong_type() {
1355        let schema = table_schema_simple();
1356        let expr = Reference::new("bar").not_starts_with(Datum::string("abcd"));
1357        let bound_expr = expr.bind(schema, true);
1358        assert!(bound_expr.is_err());
1359    }
1360
1361    #[test]
1362    fn test_bind_in() {
1363        let schema = table_schema_simple();
1364        let expr = Reference::new("bar").is_in([Datum::int(10), Datum::int(20)]);
1365        let bound_expr = expr.bind(schema, true).unwrap();
1366        assert_eq!(&format!("{bound_expr}"), "bar IN (20, 10)");
1367        test_bound_predicate_serialize_diserialize(bound_expr);
1368    }
1369
1370    #[test]
1371    fn test_bind_in_empty() {
1372        let schema = table_schema_simple();
1373        let expr = Reference::new("bar").is_in(vec![]);
1374        let bound_expr = expr.bind(schema, true).unwrap();
1375        assert_eq!(&format!("{bound_expr}"), "False");
1376        test_bound_predicate_serialize_diserialize(bound_expr);
1377    }
1378
1379    #[test]
1380    fn test_bind_in_one_literal() {
1381        let schema = table_schema_simple();
1382        let expr = Reference::new("bar").is_in(vec![Datum::int(10)]);
1383        let bound_expr = expr.bind(schema, true).unwrap();
1384        assert_eq!(&format!("{bound_expr}"), "bar = 10");
1385        test_bound_predicate_serialize_diserialize(bound_expr);
1386    }
1387
1388    #[test]
1389    fn test_bind_in_wrong_type() {
1390        let schema = table_schema_simple();
1391        let expr = Reference::new("bar").is_in(vec![Datum::int(10), Datum::string("abcd")]);
1392        let bound_expr = expr.bind(schema, true);
1393        assert!(bound_expr.is_err());
1394    }
1395
1396    #[test]
1397    fn test_bind_not_in() {
1398        let schema = table_schema_simple();
1399        let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::int(20)]);
1400        let bound_expr = expr.bind(schema, true).unwrap();
1401        assert_eq!(&format!("{bound_expr}"), "bar NOT IN (20, 10)");
1402        test_bound_predicate_serialize_diserialize(bound_expr);
1403    }
1404
1405    #[test]
1406    fn test_bind_not_in_empty() {
1407        let schema = table_schema_simple();
1408        let expr = Reference::new("bar").is_not_in(vec![]);
1409        let bound_expr = expr.bind(schema, true).unwrap();
1410        assert_eq!(&format!("{bound_expr}"), "True");
1411        test_bound_predicate_serialize_diserialize(bound_expr);
1412    }
1413
1414    #[test]
1415    fn test_bind_not_in_one_literal() {
1416        let schema = table_schema_simple();
1417        let expr = Reference::new("bar").is_not_in(vec![Datum::int(10)]);
1418        let bound_expr = expr.bind(schema, true).unwrap();
1419        assert_eq!(&format!("{bound_expr}"), "bar != 10");
1420        test_bound_predicate_serialize_diserialize(bound_expr);
1421    }
1422
1423    #[test]
1424    fn test_bind_not_in_wrong_type() {
1425        let schema = table_schema_simple();
1426        let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::string("abcd")]);
1427        let bound_expr = expr.bind(schema, true);
1428        assert!(bound_expr.is_err());
1429    }
1430
1431    #[test]
1432    fn test_bind_and() {
1433        let schema = table_schema_simple();
1434        let expr = Reference::new("bar")
1435            .less_than(Datum::int(10))
1436            .and(Reference::new("foo").is_null());
1437        let bound_expr = expr.bind(schema, true).unwrap();
1438        assert_eq!(&format!("{bound_expr}"), "(bar < 10) AND (foo IS NULL)");
1439        test_bound_predicate_serialize_diserialize(bound_expr);
1440    }
1441
1442    #[test]
1443    fn test_bind_and_always_false() {
1444        let schema = table_schema_simple();
1445        let expr = Reference::new("foo")
1446            .less_than(Datum::string("abcd"))
1447            .and(Reference::new("bar").is_null());
1448        let bound_expr = expr.bind(schema, true).unwrap();
1449        assert_eq!(&format!("{bound_expr}"), "False");
1450        test_bound_predicate_serialize_diserialize(bound_expr);
1451    }
1452
1453    #[test]
1454    fn test_bind_and_always_true() {
1455        let schema = table_schema_simple();
1456        let expr = Reference::new("foo")
1457            .less_than(Datum::string("abcd"))
1458            .and(Reference::new("bar").is_not_null());
1459        let bound_expr = expr.bind(schema, true).unwrap();
1460        assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#);
1461        test_bound_predicate_serialize_diserialize(bound_expr);
1462    }
1463
1464    #[test]
1465    fn test_bind_or() {
1466        let schema = table_schema_simple();
1467        let expr = Reference::new("bar")
1468            .less_than(Datum::int(10))
1469            .or(Reference::new("foo").is_null());
1470        let bound_expr = expr.bind(schema, true).unwrap();
1471        assert_eq!(&format!("{bound_expr}"), "(bar < 10) OR (foo IS NULL)");
1472        test_bound_predicate_serialize_diserialize(bound_expr);
1473    }
1474
1475    #[test]
1476    fn test_bind_or_always_true() {
1477        let schema = table_schema_simple();
1478        let expr = Reference::new("foo")
1479            .less_than(Datum::string("abcd"))
1480            .or(Reference::new("bar").is_not_null());
1481        let bound_expr = expr.bind(schema, true).unwrap();
1482        assert_eq!(&format!("{bound_expr}"), "True");
1483        test_bound_predicate_serialize_diserialize(bound_expr);
1484    }
1485
1486    #[test]
1487    fn test_bind_or_always_false() {
1488        let schema = table_schema_simple();
1489        let expr = Reference::new("foo")
1490            .less_than(Datum::string("abcd"))
1491            .or(Reference::new("bar").is_null());
1492        let bound_expr = expr.bind(schema, true).unwrap();
1493        assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#);
1494        test_bound_predicate_serialize_diserialize(bound_expr);
1495    }
1496
1497    #[test]
1498    fn test_bind_not() {
1499        let schema = table_schema_simple();
1500        let expr = !Reference::new("bar").less_than(Datum::int(10));
1501        let bound_expr = expr.bind(schema, true).unwrap();
1502        assert_eq!(&format!("{bound_expr}"), "NOT (bar < 10)");
1503        test_bound_predicate_serialize_diserialize(bound_expr);
1504    }
1505
1506    #[test]
1507    fn test_bind_not_always_true() {
1508        let schema = table_schema_simple();
1509        let expr = !Reference::new("bar").is_not_null();
1510        let bound_expr = expr.bind(schema, true).unwrap();
1511        assert_eq!(&format!("{bound_expr}"), "False");
1512        test_bound_predicate_serialize_diserialize(bound_expr);
1513    }
1514
1515    #[test]
1516    fn test_bind_not_always_false() {
1517        let schema = table_schema_simple();
1518        let expr = !Reference::new("bar").is_null();
1519        let bound_expr = expr.bind(schema, true).unwrap();
1520        assert_eq!(&format!("{bound_expr}"), r#"True"#);
1521        test_bound_predicate_serialize_diserialize(bound_expr);
1522    }
1523
1524    #[test]
1525    fn test_bound_predicate_rewrite_not_binary() {
1526        let schema = table_schema_simple();
1527
1528        // Test NOT elimination on binary predicates: NOT(bar < 10) => bar >= 10
1529        let predicate = Reference::new("bar").less_than(Datum::int(10)).not();
1530        let bound_predicate = predicate.bind(schema.clone(), true).unwrap();
1531        let result = bound_predicate.rewrite_not();
1532
1533        // The result should be bar >= 10
1534        let expected_predicate = Reference::new("bar").greater_than_or_equal_to(Datum::int(10));
1535        let expected_bound = expected_predicate.bind(schema, true).unwrap();
1536
1537        assert_eq!(result, expected_bound);
1538        assert_eq!(&format!("{result}"), "bar >= 10");
1539    }
1540
1541    #[test]
1542    fn test_bound_predicate_rewrite_not_unary() {
1543        let schema = table_schema_simple();
1544
1545        // Test NOT elimination on unary predicates: NOT(foo IS NULL) => foo IS NOT NULL
1546        let predicate = Reference::new("foo").is_null().not();
1547        let bound_predicate = predicate.bind(schema.clone(), true).unwrap();
1548        let result = bound_predicate.rewrite_not();
1549
1550        // The result should be foo IS NOT NULL
1551        let expected_predicate = Reference::new("foo").is_not_null();
1552        let expected_bound = expected_predicate.bind(schema, true).unwrap();
1553
1554        assert_eq!(result, expected_bound);
1555        assert_eq!(&format!("{result}"), "foo IS NOT NULL");
1556    }
1557
1558    #[test]
1559    fn test_bound_predicate_rewrite_not_set() {
1560        let schema = table_schema_simple();
1561
1562        // Test NOT elimination on set predicates: NOT(bar IN (10, 20)) => bar NOT IN (10, 20)
1563        let predicate = Reference::new("bar")
1564            .is_in([Datum::int(10), Datum::int(20)])
1565            .not();
1566        let bound_predicate = predicate.bind(schema.clone(), true).unwrap();
1567        let result = bound_predicate.rewrite_not();
1568
1569        // The result should be bar NOT IN (10, 20)
1570        let expected_predicate = Reference::new("bar").is_not_in([Datum::int(10), Datum::int(20)]);
1571        let expected_bound = expected_predicate.bind(schema, true).unwrap();
1572
1573        assert_eq!(result, expected_bound);
1574        // Note: HashSet order may vary, so we check that it contains the expected format
1575        let result_str = format!("{result}");
1576        assert!(
1577            result_str.contains("bar NOT IN")
1578                && result_str.contains("10")
1579                && result_str.contains("20")
1580        );
1581    }
1582
1583    #[test]
1584    fn test_bound_predicate_rewrite_not_and_demorgan() {
1585        let schema = table_schema_simple();
1586
1587        // Test De Morgan's law: NOT(A AND B) = (NOT A) OR (NOT B)
1588        // NOT((bar < 10) AND (foo IS NULL)) => (bar >= 10) OR (foo IS NOT NULL)
1589        let predicate = Reference::new("bar")
1590            .less_than(Datum::int(10))
1591            .and(Reference::new("foo").is_null())
1592            .not();
1593
1594        let bound_predicate = predicate.bind(schema.clone(), true).unwrap();
1595        let result = bound_predicate.rewrite_not();
1596
1597        // Expected: (bar >= 10) OR (foo IS NOT NULL)
1598        let expected_predicate = Reference::new("bar")
1599            .greater_than_or_equal_to(Datum::int(10))
1600            .or(Reference::new("foo").is_not_null());
1601
1602        let expected_bound = expected_predicate.bind(schema, true).unwrap();
1603
1604        assert_eq!(result, expected_bound);
1605        assert_eq!(&format!("{result}"), "(bar >= 10) OR (foo IS NOT NULL)");
1606    }
1607
1608    #[test]
1609    fn test_bound_predicate_rewrite_not_or_demorgan() {
1610        let schema = table_schema_simple();
1611
1612        // Test De Morgan's law: NOT(A OR B) = (NOT A) AND (NOT B)
1613        // NOT((bar < 10) OR (foo IS NULL)) => (bar >= 10) AND (foo IS NOT NULL)
1614        let predicate = Reference::new("bar")
1615            .less_than(Datum::int(10))
1616            .or(Reference::new("foo").is_null())
1617            .not();
1618
1619        let bound_predicate = predicate.bind(schema.clone(), true).unwrap();
1620        let result = bound_predicate.rewrite_not();
1621
1622        // Expected: (bar >= 10) AND (foo IS NOT NULL)
1623        let expected_predicate = Reference::new("bar")
1624            .greater_than_or_equal_to(Datum::int(10))
1625            .and(Reference::new("foo").is_not_null());
1626
1627        let expected_bound = expected_predicate.bind(schema, true).unwrap();
1628
1629        assert_eq!(result, expected_bound);
1630        assert_eq!(&format!("{result}"), "(bar >= 10) AND (foo IS NOT NULL)");
1631    }
1632
1633    #[test]
1634    fn test_bound_predicate_rewrite_not_double_negative() {
1635        let schema = table_schema_simple();
1636
1637        // Test double negative elimination: NOT(NOT(bar < 10)) => bar < 10
1638        let predicate = Reference::new("bar").less_than(Datum::int(10)).not().not();
1639        let bound_predicate = predicate.bind(schema.clone(), true).unwrap();
1640        let result = bound_predicate.rewrite_not();
1641
1642        // The result should be bar < 10 (original predicate)
1643        let expected_predicate = Reference::new("bar").less_than(Datum::int(10));
1644        let expected_bound = expected_predicate.bind(schema, true).unwrap();
1645
1646        assert_eq!(result, expected_bound);
1647        assert_eq!(&format!("{result}"), "bar < 10");
1648    }
1649
1650    #[test]
1651    fn test_bound_predicate_rewrite_not_always_true_false() {
1652        let schema = table_schema_simple();
1653
1654        // Test NOT(AlwaysTrue) => AlwaysFalse
1655        let predicate = Reference::new("bar").is_not_null().not(); // This becomes NOT(AlwaysTrue) since bar is required
1656        let bound_predicate = predicate.bind(schema.clone(), true).unwrap();
1657        let result = bound_predicate.rewrite_not();
1658
1659        assert_eq!(result, BoundPredicate::AlwaysFalse);
1660        assert_eq!(&format!("{result}"), "False");
1661
1662        // Test NOT(AlwaysFalse) => AlwaysTrue
1663        let predicate2 = Reference::new("bar").is_null().not(); // This becomes NOT(AlwaysFalse) since bar is required
1664        let bound_predicate2 = predicate2.bind(schema, true).unwrap();
1665        let result2 = bound_predicate2.rewrite_not();
1666
1667        assert_eq!(result2, BoundPredicate::AlwaysTrue);
1668        assert_eq!(&format!("{result2}"), "True");
1669    }
1670
1671    #[test]
1672    fn test_bound_predicate_rewrite_not_complex_nested() {
1673        let schema = table_schema_simple();
1674
1675        // Test complex nested expression:
1676        // NOT(NOT((bar >= 10) AND (foo IS NOT NULL)) OR (bar < 5))
1677        // Should become: ((bar >= 10) AND (foo IS NOT NULL)) AND (bar >= 5)
1678        let inner_predicate = Reference::new("bar")
1679            .greater_than_or_equal_to(Datum::int(10))
1680            .and(Reference::new("foo").is_not_null())
1681            .not();
1682
1683        let complex_predicate = inner_predicate
1684            .or(Reference::new("bar").less_than(Datum::int(5)))
1685            .not();
1686
1687        let bound_predicate = complex_predicate.bind(schema.clone(), true).unwrap();
1688        let result = bound_predicate.rewrite_not();
1689
1690        // Expected: ((bar >= 10) AND (foo IS NOT NULL)) AND (bar >= 5)
1691        // This is because NOT(NOT(A) OR B) = A AND NOT(B)
1692        let expected_predicate = Reference::new("bar")
1693            .greater_than_or_equal_to(Datum::int(10))
1694            .and(Reference::new("foo").is_not_null())
1695            .and(Reference::new("bar").greater_than_or_equal_to(Datum::int(5)));
1696
1697        let expected_bound = expected_predicate.bind(schema, true).unwrap();
1698
1699        assert_eq!(result, expected_bound);
1700        assert_eq!(
1701            &format!("{result}"),
1702            "((bar >= 10) AND (foo IS NOT NULL)) AND (bar >= 5)"
1703        );
1704    }
1705}