Skip to content

Commit

Permalink
fix: display expr considering associativity (#15685)
Browse files Browse the repository at this point in the history
* fix: display expr considering associativity

* fix
  • Loading branch information
andylokandy committed May 30, 2024
1 parent 195a552 commit c0b233e
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 28 deletions.
82 changes: 55 additions & 27 deletions src/query/ast/src/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use derive_visitor::DriveMut;
use enum_as_inner::EnumAsInner;
use ethnum::i256;
use pratt::Affix;
use pratt::Precedence;
use pratt::Associativity;

use super::ColumnRef;
use super::OrderByExpr;
Expand Down Expand Up @@ -415,26 +415,54 @@ impl Expr {
]
}

pub fn precedence(&self) -> Option<Precedence> {
match ExprElement::from(self.clone()).affix() {
Affix::Nilfix => None,
Affix::Infix(p, _) => Some(p),
Affix::Prefix(p) => Some(p),
Affix::Postfix(p) => Some(p),
}
fn affix(&self) -> Affix {
ExprElement::from(self.clone()).affix()
}
}

impl Display for Expr {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
fn needs_parentheses(parent: Option<Affix>, child: Affix, is_left: bool) -> bool {
match (parent, child) {
(Some(Affix::Infix(parent_prec, parent_assoc)), Affix::Infix(child_prec, _)) => {
if parent_prec < child_prec {
return false;
}
if parent_prec > child_prec {
return true;
}
if matches!(parent_assoc, Associativity::Left) && !is_left {
return true;
}
if matches!(parent_assoc, Associativity::Right) && is_left {
return true;
}
}
(
Some(
Affix::Infix(parent_prec, _)
| Affix::Prefix(parent_prec)
| Affix::Postfix(parent_prec),
),
Affix::Infix(child_prec, _)
| Affix::Prefix(child_prec)
| Affix::Postfix(child_prec),
) => {
return parent_prec > child_prec;
}
_ => (),
}
false
}

fn write_expr(
expr: &Expr,
min_precedence: Precedence,
parent: Option<Affix>,
is_left: bool,
f: &mut Formatter,
) -> std::fmt::Result {
let prec = expr.precedence();
let need_paren = prec.map(|p| p < min_precedence).unwrap_or(false);
let min_prec = prec.unwrap_or(Precedence(0));
let affix = expr.affix();
let need_paren = needs_parentheses(parent, affix, is_left);

if need_paren {
write!(f, "(")?;
Expand All @@ -449,7 +477,7 @@ impl Display for Expr {
}
}
Expr::IsNull { expr, not, .. } => {
write_expr(expr, min_prec, f)?;
write_expr(expr, Some(affix), true, f)?;
write!(f, " IS")?;
if *not {
write!(f, " NOT")?;
Expand All @@ -459,19 +487,19 @@ impl Display for Expr {
Expr::IsDistinctFrom {
left, right, not, ..
} => {
write_expr(left, min_prec, f)?;
write_expr(left, Some(affix), true, f)?;
write!(f, " IS")?;
if *not {
write!(f, " NOT")?;
}
write!(f, " DISTINCT FROM ")?;
write_expr(right, min_prec, f)?;
write_expr(right, Some(affix), true, f)?;
}

Expr::InList {
expr, list, not, ..
} => {
write_expr(expr, min_prec, f)?;
write_expr(expr, Some(affix), true, f)?;
if *not {
write!(f, " NOT")?;
}
Expand All @@ -485,7 +513,7 @@ impl Display for Expr {
not,
..
} => {
write_expr(expr, min_prec, f)?;
write_expr(expr, Some(affix), true, f)?;
if *not {
write!(f, " NOT")?;
}
Expand All @@ -498,7 +526,7 @@ impl Display for Expr {
not,
..
} => {
write_expr(expr, min_prec, f)?;
write_expr(expr, Some(affix), true, f)?;
if *not {
write!(f, " NOT")?;
}
Expand All @@ -508,28 +536,28 @@ impl Display for Expr {
match op {
// TODO (xieqijun) Maybe special attribute are provided to check whether the symbol is before or after.
UnaryOperator::Factorial => {
write_expr(expr, min_prec, f)?;
write_expr(expr, Some(affix), true, f)?;
write!(f, " {op}")?;
}
_ => {
write!(f, "{op} ")?;
write_expr(expr, min_prec, f)?;
write_expr(expr, Some(affix), true, f)?;
}
}
}
Expr::BinaryOp {
op, left, right, ..
} => {
write_expr(left, min_prec, f)?;
write_expr(left, Some(affix), true, f)?;
write!(f, " {op} ")?;
write_expr(right, min_prec, f)?;
write_expr(right, Some(affix), false, f)?;
}
Expr::JsonOp {
op, left, right, ..
} => {
write_expr(left, min_prec, f)?;
write_expr(left, Some(affix), true, f)?;
write!(f, " {op} ")?;
write_expr(right, min_prec, f)?;
write_expr(right, Some(affix), true, f)?;
}
Expr::Cast {
expr,
Expand All @@ -538,7 +566,7 @@ impl Display for Expr {
..
} => {
if *pg_style {
write_expr(expr, min_prec, f)?;
write_expr(expr, Some(affix), true, f)?;
write!(f, "::{target_type}")?;
} else {
write!(f, "CAST({expr} AS {target_type})")?;
Expand Down Expand Up @@ -641,7 +669,7 @@ impl Display for Expr {
write!(f, "({subquery})")?;
}
Expr::MapAccess { expr, accessor, .. } => {
write_expr(expr, min_prec, f)?;
write_expr(expr, Some(affix), true, f)?;
match accessor {
MapAccessor::Bracket { key } => write!(f, "[{key}]")?,
MapAccessor::DotNumber { key } => write!(f, ".{key}")?,
Expand Down Expand Up @@ -697,7 +725,7 @@ impl Display for Expr {
Ok(())
}

write_expr(self, Precedence(0), f)
write_expr(self, None, true, f)
}
}

Expand Down
1 change: 1 addition & 0 deletions src/query/ast/tests/it/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ fn test_statement() {
r#"insert into table t select * from t2;"#,
r#"select parse_json('{"k1": [0, 1, 2]}').k1[0];"#,
r#"SELECT avg((number > 314)::UInt32);"#,
r#"SELECT 1 - (2 + 3);"#,
r#"CREATE STAGE ~"#,
r#"CREATE STAGE IF NOT EXISTS test_stage 's3://load/files/' credentials=(aws_key_id='1a2b3c', aws_secret_key='4x5y6z') file_format=(type = CSV, compression = GZIP record_delimiter=',')"#,
r#"CREATE STAGE IF NOT EXISTS test_stage url='s3://load/files/' credentials=(aws_key_id='1a2b3c', aws_secret_key='4x5y6z') file_format=(type = CSV, compression = GZIP record_delimiter=',')"#,
Expand Down
76 changes: 76 additions & 0 deletions src/query/ast/tests/it/testdata/statement.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10139,6 +10139,82 @@ Query(
)


---------- Input ----------
SELECT 1 - (2 + 3);
---------- Output ---------
SELECT 1 - (2 + 3)
---------- AST ------------
Query(
Query {
span: Some(
0..18,
),
with: None,
body: Select(
SelectStmt {
span: Some(
0..18,
),
hints: None,
distinct: false,
top_n: None,
select_list: [
AliasedExpr {
expr: BinaryOp {
span: Some(
9..10,
),
op: Minus,
left: Literal {
span: Some(
7..8,
),
value: UInt64(
1,
),
},
right: BinaryOp {
span: Some(
14..15,
),
op: Plus,
left: Literal {
span: Some(
12..13,
),
value: UInt64(
2,
),
},
right: Literal {
span: Some(
16..17,
),
value: UInt64(
3,
),
},
},
},
alias: None,
},
],
from: [],
selection: None,
group_by: None,
having: None,
window_list: None,
qualify: None,
},
),
order_by: [],
limit: [],
offset: None,
ignore_result: false,
},
)


---------- Input ----------
CREATE STAGE ~
---------- Output ---------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ SELECT definition FROM SYSTEM.USER_FUNCTIONS ORDER BY name;
(Float32 NULL, Float64 NULL) RETURNS Float64 NULL LANGUAGE python HANDLER = add_float ADDRESS = http://0.0.0.0:8815
(Int8 NULL, Int16 NULL, Int32 NULL, Int64 NULL) RETURNS Int64 NULL LANGUAGE python HANDLER = add_signed ADDRESS = http://0.0.0.0:8815
(UInt8 NULL, UInt16 NULL, UInt32 NULL, UInt64 NULL) RETURNS UInt64 NULL LANGUAGE python HANDLER = add_unsigned ADDRESS = http://0.0.0.0:8815
(a, b, c, d, e) -> a + c * e / b - d
(a, b, c, d, e) -> a + c * (e / b) - d
(p) -> NOT is_null(p)

# DROP FUNCTIONS
Expand Down

0 comments on commit c0b233e

Please sign in to comment.