From 9e638dae1661d5fd907433b9a576fd718d0ae70f Mon Sep 17 00:00:00 2001 From: Lach Date: Sat, 11 Oct 2025 20:38:41 +0200 Subject: [PATCH] feat!: strict equality operator Fixes: https://github.com/PRQL/prql/issues/4289 --- grammars/prql-lezer/src/prql.grammar | 4 +- prqlc/prqlc-parser/src/lexer/lr.rs | 15 +++ prqlc/prqlc-parser/src/lexer/mod.rs | 2 + prqlc/prqlc-parser/src/parser/expr.rs | 3 + prqlc/prqlc-parser/src/parser/pr/ops.rs | 6 + prqlc/prqlc/src/cli/highlight.rs | 2 + prqlc/prqlc/src/codegen/ast.rs | 2 + prqlc/prqlc/src/semantic/ast_expand.rs | 8 +- ...tic__resolver__test__frames_and_names.snap | 11 +- ...c__resolver__test__functions_pipeline.snap | 10 +- ...ms__tests__aggregate_positional_arg-2.snap | 110 +++++++++--------- .../src/semantic/resolver/static_eval.rs | 22 +++- prqlc/prqlc/src/semantic/std.prql | 2 + prqlc/prqlc/src/sql/gen_expr.rs | 87 ++++++++++---- prqlc/prqlc/src/sql/pq/preprocess.rs | 6 +- prqlc/prqlc/src/sql/std.sql.prql | 6 + prqlc/prqlc/tests/integration/sql.rs | 12 +- web/book/src/reference/spec/null.md | 10 +- web/book/src/reference/syntax/README.md | 2 +- web/book/src/reference/syntax/operators.md | 2 +- web/playground/src/workbench/prql-syntax.js | 2 + web/website/data/examples/null-handling.yaml | 4 +- 22 files changed, 221 insertions(+), 107 deletions(-) diff --git a/grammars/prql-lezer/src/prql.grammar b/grammars/prql-lezer/src/prql.grammar index b753cb854cbe..495663a65354 100644 --- a/grammars/prql-lezer/src/prql.grammar +++ b/grammars/prql-lezer/src/prql.grammar @@ -62,7 +62,7 @@ testInner { binaryTest | unaryTest | expression } binaryTest[@name="BinaryExpression"] { testInner !or LogicOp<"||" | "??"> testInner | testInner !and LogicOp<"&&"> testInner | - testInner !compare (CompareOp<"==" | "!=" | "~=" | ">=" | "<=" | ">" | "<"> | kw<"in">) testInner + testInner !compare (CompareOp<"==" | "!=" | "~=" | "===" | "!==" | ">=" | "<=" | ">" | "<"> | kw<"in">) testInner } unaryTest[@name="UnaryExpression"] { kw<"!"> testInner } @@ 
-100,7 +100,7 @@ ParenthesizedExpression { "(" expression ")" } UnaryExpression { !prefix ArithOp<"+" | "-"> expression | - !prefix CompareOp<"=="> Identifier + !prefix CompareOp<"==" | "==="> Identifier } // Because this is outside tokens, we can't disallow whitespace. diff --git a/prqlc/prqlc-parser/src/lexer/lr.rs b/prqlc/prqlc-parser/src/lexer/lr.rs index 93f1f5a8cda2..d0c07bd4d1a4 100644 --- a/prqlc/prqlc-parser/src/lexer/lr.rs +++ b/prqlc/prqlc-parser/src/lexer/lr.rs @@ -42,6 +42,8 @@ pub enum TokenKind { ArrowFat, // => Eq, // == Ne, // != + SEq, // === + SNe, // !== Gte, // >= Lte, // <= RegexSearch, // ~= @@ -90,6 +92,17 @@ pub enum Literal { ValueAndUnit(ValueAndUnit), } +impl Literal { + // FIXME: Should it be a `null_safe_eq`, with `PartialEq` implementing + // 3VAL logic instead? This requires manual `PartialEq` implementation. + pub fn three_val_eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Null, Self::Null) => false, + _ => self == other, + } + } +} + impl TokenKind { pub fn range(bind_left: bool, bind_right: bool) -> Self { TokenKind::Range { @@ -197,6 +210,8 @@ impl std::fmt::Display for TokenKind { TokenKind::ArrowFat => f.write_str("=>"), TokenKind::Eq => f.write_str("=="), TokenKind::Ne => f.write_str("!="), + TokenKind::SEq => f.write_str("==="), + TokenKind::SNe => f.write_str("!=="), TokenKind::Gte => f.write_str(">="), TokenKind::Lte => f.write_str("<="), TokenKind::RegexSearch => f.write_str("~="), diff --git a/prqlc/prqlc-parser/src/lexer/mod.rs b/prqlc/prqlc-parser/src/lexer/mod.rs index e71342bb92cc..a1b4f61e7da6 100644 --- a/prqlc/prqlc-parser/src/lexer/mod.rs +++ b/prqlc/prqlc-parser/src/lexer/mod.rs @@ -80,6 +80,8 @@ fn lex_token() -> impl Parser> { let control_multi = choice(( just("->").to(TokenKind::ArrowThin), just("=>").to(TokenKind::ArrowFat), + just("===").to(TokenKind::SEq), + just("!==").to(TokenKind::SNe), just("==").to(TokenKind::Eq), just("!=").to(TokenKind::Ne), just(">=").to(TokenKind::Gte), diff 
--git a/prqlc/prqlc-parser/src/parser/expr.rs b/prqlc/prqlc-parser/src/parser/expr.rs index 46a5d4b024e7..967fde5877ab 100644 --- a/prqlc/prqlc-parser/src/parser/expr.rs +++ b/prqlc/prqlc-parser/src/parser/expr.rs @@ -543,6 +543,7 @@ fn operator_unary() -> impl Parser + Clone { .or(ctrl('-').to(UnOp::Neg)) .or(ctrl('!').to(UnOp::Not)) .or(just(TokenKind::Eq).to(UnOp::EqSelf)) + .or(just(TokenKind::SEq).to(UnOp::SEqSelf)) } fn operator_pow() -> impl Parser + Clone { just(TokenKind::Pow).to(BinOp::Pow) @@ -560,6 +561,8 @@ fn operator_compare() -> impl Parser + Clone { choice(( just(TokenKind::Eq).to(BinOp::Eq), just(TokenKind::Ne).to(BinOp::Ne), + just(TokenKind::SEq).to(BinOp::SEq), + just(TokenKind::SNe).to(BinOp::SNe), just(TokenKind::Lte).to(BinOp::Lte), just(TokenKind::Gte).to(BinOp::Gte), just(TokenKind::RegexSearch).to(BinOp::RegexSearch), diff --git a/prqlc/prqlc-parser/src/parser/pr/ops.rs b/prqlc/prqlc-parser/src/parser/pr/ops.rs index 3ac108b9a307..387988b56e66 100644 --- a/prqlc/prqlc-parser/src/parser/pr/ops.rs +++ b/prqlc/prqlc-parser/src/parser/pr/ops.rs @@ -23,6 +23,8 @@ pub enum UnOp { Not, #[strum(to_string = "==")] EqSelf, + #[strum(to_string = "===")] + SEqSelf, } #[derive( @@ -57,6 +59,10 @@ pub enum BinOp { Eq, #[strum(to_string = "!=")] Ne, + #[strum(to_string = "===")] + SEq, + #[strum(to_string = "!==")] + SNe, #[strum(to_string = ">")] Gt, #[strum(to_string = "<")] diff --git a/prqlc/prqlc/src/cli/highlight.rs b/prqlc/prqlc/src/cli/highlight.rs index 096f0de36770..859a94841e8e 100644 --- a/prqlc/prqlc/src/cli/highlight.rs +++ b/prqlc/prqlc/src/cli/highlight.rs @@ -52,6 +52,8 @@ fn highlight_token_kind(token: &TokenKind) -> String { | TokenKind::ArrowFat | TokenKind::Eq | TokenKind::Ne + | TokenKind::SEq + | TokenKind::SNe | TokenKind::Gte | TokenKind::Lte | TokenKind::RegexSearch => output.push_str(&format!("{}", token)), diff --git a/prqlc/prqlc/src/codegen/ast.rs b/prqlc/prqlc/src/codegen/ast.rs index 51517f0ea904..76955ceeb31f 100644 --- 
a/prqlc/prqlc/src/codegen/ast.rs +++ b/prqlc/prqlc/src/codegen/ast.rs @@ -295,6 +295,8 @@ fn binding_strength(expr: &pr::ExprKind) -> u8 { pr::BinOp::Add | pr::BinOp::Sub => 17, pr::BinOp::Eq | pr::BinOp::Ne + | pr::BinOp::SEq + | pr::BinOp::SNe | pr::BinOp::Gt | pr::BinOp::Lt | pr::BinOp::Gte diff --git a/prqlc/prqlc/src/semantic/ast_expand.rs b/prqlc/prqlc/src/semantic/ast_expand.rs index 8394a41fdbba..c90bbdd28769 100644 --- a/prqlc/prqlc/src/semantic/ast_expand.rs +++ b/prqlc/prqlc/src/semantic/ast_expand.rs @@ -162,7 +162,9 @@ fn expand_unary(pr::UnaryExpr { op, expr }: pr::UnaryExpr) -> Result ["std", "neg"], Not => ["std", "not"], Add => return Ok(expr.kind), - EqSelf => { + EqSelf | SEqSelf => { + let method = if op == EqSelf { "eq" } else { "seq" }; + let pl::ExprKind::Ident(ident) = expr.kind else { return Err(Error::new_simple( "you can only use column names with self-equality operator", @@ -188,7 +190,7 @@ fn expand_unary(pr::UnaryExpr { op, expr }: pr::UnaryExpr) -> Result Result

vec!["std", "sub"], pr::BinOp::Eq => vec!["std", "eq"], pr::BinOp::Ne => vec!["std", "ne"], + pr::BinOp::SEq => vec!["std", "seq"], + pr::BinOp::SNe => vec!["std", "sne"], pr::BinOp::Gt => vec!["std", "gt"], pr::BinOp::Lt => vec!["std", "lt"], pr::BinOp::Gte => vec!["std", "gte"], diff --git a/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__test__frames_and_names.snap b/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__test__frames_and_names.snap index 675179623679..cfe5c694e2fd 100644 --- a/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__test__frames_and_names.snap +++ b/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__test__frames_and_names.snap @@ -7,28 +7,27 @@ columns: name: - orders - customer_no - target_id: 123 + target_id: 125 target_name: ~ - Single: name: - orders - gross - target_id: 124 + target_id: 126 target_name: ~ - Single: name: - orders - tax - target_id: 125 + target_id: 127 target_name: ~ - Single: name: ~ - target_id: 126 + target_id: 128 target_name: ~ inputs: - - id: 121 + - id: 123 name: orders table: - default_db - orders - diff --git a/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__test__functions_pipeline.snap b/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__test__functions_pipeline.snap index 233b068133f6..2a590cf4cee1 100644 --- a/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__test__functions_pipeline.snap +++ b/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__test__functions_pipeline.snap @@ -14,9 +14,9 @@ expression: "resolve_derive(r#\"\n from a\n derive one = ( kind: Array: kind: Any - span: "0:1929-1936" + span: "0:2025-2032" name: ~ - span: "0:1928-1937" + span: "0:2024-2033" name: array span: "1:52-55" alias: one @@ -26,12 +26,12 @@ expression: "resolve_derive(r#\"\n from a\n derive one = ( - - ~ - kind: Primitive: Int - span: "0:4123-4126" + span: 
"0:4219-4222" name: ~ - - ~ - kind: Primitive: Float - span: "0:4130-4135" + span: "0:4226-4231" name: ~ - span: "0:4123-4135" + span: "0:4219-4231" name: ~ diff --git a/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__transforms__tests__aggregate_positional_arg-2.snap b/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__transforms__tests__aggregate_positional_arg-2.snap index 75e5d7b100b4..347841450e1d 100644 --- a/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__transforms__tests__aggregate_positional_arg-2.snap +++ b/prqlc/prqlc/src/semantic/resolver/snapshots/prqlc__semantic__resolver__transforms__tests__aggregate_positional_arg-2.snap @@ -15,19 +15,19 @@ TransformCall: Tuple: - Wildcard: kind: Any - span: "0:2045-2052" + span: "0:2141-2148" name: ~ - span: "0:2042-2053" + span: "0:2138-2149" name: tuple - span: "0:2158-2165" + span: "0:2254-2261" name: relation lineage: columns: - All: - input_id: 118 + input_id: 120 except: [] inputs: - - id: 118 + - id: 120 name: c_invoice table: - default_db @@ -48,9 +48,9 @@ TransformCall: kind: Array: kind: Any - span: "0:1929-1936" + span: "0:2025-2032" name: ~ - span: "0:1928-1937" + span: "0:2024-2033" name: array span: "1:73-87" ty: @@ -59,14 +59,14 @@ TransformCall: - - ~ - kind: Primitive: Float - span: "0:4188-4193" + span: "0:4284-4289" name: ~ - - ~ - kind: Singleton: "Null" - span: "0:4197-4201" + span: "0:4293-4297" name: ~ - span: "0:4188-4201" + span: "0:4284-4297" name: ~ span: "1:73-87" ty: @@ -79,14 +79,14 @@ TransformCall: - - ~ - kind: Primitive: Float - span: "0:4188-4193" + span: "0:4284-4289" name: ~ - - ~ - kind: Singleton: "Null" - span: "0:4197-4201" + span: "0:4293-4297" name: ~ - span: "0:4188-4201" + span: "0:4284-4297" name: ~ span: ~ name: ~ @@ -106,55 +106,55 @@ TransformCall: - - ~ - kind: Primitive: Int - span: "0:1963-1966" + span: "0:2059-2062" name: ~ - - ~ - kind: Primitive: Float - span: "0:1970-1975" + span: "0:2066-2071" 
name: ~ - - ~ - kind: Primitive: Bool - span: "0:1979-1983" + span: "0:2075-2079" name: ~ - - ~ - kind: Primitive: Text - span: "0:1987-1991" + span: "0:2083-2087" name: ~ - - ~ - kind: Primitive: Date - span: "0:1995-1999" + span: "0:2091-2095" name: ~ - - ~ - kind: Primitive: Time - span: "0:2003-2007" + span: "0:2099-2103" name: ~ - - ~ - kind: Primitive: Timestamp - span: "0:2011-2020" + span: "0:2107-2116" name: ~ - - ~ - kind: Singleton: "Null" - span: "0:2024-2028" + span: "0:2120-2124" name: ~ - span: "0:1963-2028" + span: "0:2059-2124" name: scalar - - ~ - kind: Tuple: - Wildcard: kind: Any - span: "0:2045-2052" + span: "0:2141-2148" name: ~ - span: "0:2042-2053" + span: "0:2138-2149" name: tuple - span: "0:3013-3028" + span: "0:3109-3124" name: ~ span: "1:38-47" ty: @@ -170,55 +170,55 @@ TransformCall: - - ~ - kind: Primitive: Int - span: "0:1963-1966" + span: "0:2059-2062" name: ~ - - ~ - kind: Primitive: Float - span: "0:1970-1975" + span: "0:2066-2071" name: ~ - - ~ - kind: Primitive: Bool - span: "0:1979-1983" + span: "0:2075-2079" name: ~ - - ~ - kind: Primitive: Text - span: "0:1987-1991" + span: "0:2083-2087" name: ~ - - ~ - kind: Primitive: Date - span: "0:1995-1999" + span: "0:2091-2095" name: ~ - - ~ - kind: Primitive: Time - span: "0:2003-2007" + span: "0:2099-2103" name: ~ - - ~ - kind: Primitive: Timestamp - span: "0:2011-2020" + span: "0:2107-2116" name: ~ - - ~ - kind: Singleton: "Null" - span: "0:2024-2028" + span: "0:2120-2124" name: ~ - span: "0:1963-2028" + span: "0:2059-2124" name: scalar - - ~ - kind: Tuple: - Wildcard: kind: Any - span: "0:2045-2052" + span: "0:2141-2148" name: ~ - span: "0:2042-2053" + span: "0:2138-2149" name: tuple - span: "0:3013-3028" + span: "0:3109-3124" name: ~ span: ~ name: ~ @@ -238,55 +238,55 @@ ty: - - ~ - kind: Primitive: Int - span: "0:1963-1966" + span: "0:2059-2062" name: ~ - - ~ - kind: Primitive: Float - span: "0:1970-1975" + span: "0:2066-2071" name: ~ - - ~ - kind: Primitive: Bool - span: 
"0:1979-1983" + span: "0:2075-2079" name: ~ - - ~ - kind: Primitive: Text - span: "0:1987-1991" + span: "0:2083-2087" name: ~ - - ~ - kind: Primitive: Date - span: "0:1995-1999" + span: "0:2091-2095" name: ~ - - ~ - kind: Primitive: Time - span: "0:2003-2007" + span: "0:2099-2103" name: ~ - - ~ - kind: Primitive: Timestamp - span: "0:2011-2020" + span: "0:2107-2116" name: ~ - - ~ - kind: Singleton: "Null" - span: "0:2024-2028" + span: "0:2120-2124" name: ~ - span: "0:1963-2028" + span: "0:2059-2124" name: scalar - - ~ - kind: Tuple: - Wildcard: kind: Any - span: "0:2045-2052" + span: "0:2141-2148" name: ~ - span: "0:2042-2053" + span: "0:2138-2149" name: tuple - span: "0:3013-3028" + span: "0:3109-3124" name: ~ - Single: - ~ @@ -295,14 +295,14 @@ ty: - - ~ - kind: Primitive: Float - span: "0:4188-4193" + span: "0:4284-4289" name: ~ - - ~ - kind: Singleton: "Null" - span: "0:4197-4201" + span: "0:4293-4297" name: ~ - span: "0:4188-4201" + span: "0:4284-4297" name: ~ span: ~ name: ~ @@ -314,14 +314,14 @@ lineage: name: - c_invoice - issued_at - target_id: 120 + target_id: 122 target_name: ~ - Single: name: ~ - target_id: 136 + target_id: 138 target_name: ~ inputs: - - id: 118 + - id: 120 name: c_invoice table: - default_db diff --git a/prqlc/prqlc/src/semantic/resolver/static_eval.rs b/prqlc/prqlc/src/semantic/resolver/static_eval.rs index 9d579b109c6d..f918894db5b2 100644 --- a/prqlc/prqlc/src/semantic/resolver/static_eval.rs +++ b/prqlc/prqlc/src/semantic/resolver/static_eval.rs @@ -49,11 +49,31 @@ fn static_eval_rq_operator(mut expr: Expr) -> Expr { { // don't eval comparisons between different types of literals if left.as_ref() == right.as_ref() { - return Expr::new(Literal::Boolean(left == right)); + return Expr::new(Literal::Boolean(left.three_val_eq(right))); } } } "std.ne" => { + if let (ExprKind::Literal(left), ExprKind::Literal(right)) = + (&args[0].kind, &args[1].kind) + { + // don't eval comparisons between different types of literals + if left.as_ref() 
== right.as_ref() { + return Expr::new(Literal::Boolean(left != right)); + } + } + } + "std.seq" => { + if let (ExprKind::Literal(left), ExprKind::Literal(right)) = + (&args[0].kind, &args[1].kind) + { + // don't eval comparisons between different types of literals + if left.as_ref() == right.as_ref() { + return Expr::new(Literal::Boolean(left == right)); + } + } + } + "std.sne" => { if let (ExprKind::Literal(left), ExprKind::Literal(right)) = (&args[0].kind, &args[1].kind) { diff --git a/prqlc/prqlc/src/semantic/std.prql b/prqlc/prqlc/src/semantic/std.prql index 2c60440de567..20d14863a65d 100644 --- a/prqlc/prqlc/src/semantic/std.prql +++ b/prqlc/prqlc/src/semantic/std.prql @@ -23,6 +23,8 @@ let add = left right -> internal std.add let sub = left right -> internal std.sub let eq = left right -> internal std.eq let ne = left right -> internal std.ne +let seq = left right -> internal std.seq +let sne = left right -> internal std.sne let gt = left right -> internal std.gt let lt = left right -> internal std.lt let gte = left right -> internal std.gte diff --git a/prqlc/prqlc/src/sql/gen_expr.rs b/prqlc/prqlc/src/sql/gen_expr.rs index d7aeb94c06b9..d2ac95127da0 100644 --- a/prqlc/prqlc/src/sql/gen_expr.rs +++ b/prqlc/prqlc/src/sql/gen_expr.rs @@ -90,7 +90,7 @@ pub(super) fn translate_expr(expr: rq::Expr, ctx: &mut Context) -> Result { + "std.eq" | "std.ne" | "std.seq" | "std.sne" => { if let [a, b] = args.as_slice() { if a.kind == rq::ExprKind::Literal(Literal::Null) || b.kind == rq::ExprKind::Literal(Literal::Null) @@ -139,13 +139,23 @@ fn process_null(name: &str, args: &[rq::Expr], ctx: &mut Context) -> Result Result { fn translate_binary_operator( left: &rq::Expr, right: &rq::Expr, - op: BinaryOperator, + op: BinaryOperatorEx, ctx: &mut Context, ) -> Result { let strength = op.binding_strength(); @@ -298,7 +308,11 @@ fn translate_binary_operator( let left = Box::new(left.into_ast()); let right = Box::new(right.into_ast()); - Ok(sql_ast::Expr::BinaryOp {
left, op, right }) + Ok(match op { + BinaryOperatorEx::Plain(op) => sql_ast::Expr::BinaryOp { left, op, right }, + BinaryOperatorEx::SEq => sql_ast::Expr::IsNotDistinctFrom(left, right), + BinaryOperatorEx::SNotEq => sql_ast::Expr::IsDistinctFrom(left, right), + }) } fn collect_concat_args(expr: &rq::Expr) -> Vec<&rq::Expr> { @@ -358,23 +372,25 @@ fn try_into_between(expr: rq::Expr, ctx: &mut Context) -> Result Option { +fn operator_from_name(name: &str) -> Option { use BinaryOperator::*; - match name { - "std.mul" => Some(Multiply), - "std.add" => Some(Plus), - "std.sub" => Some(Minus), - "std.eq" => Some(Eq), - "std.ne" => Some(NotEq), - "std.gt" => Some(Gt), - "std.lt" => Some(Lt), - "std.gte" => Some(GtEq), - "std.lte" => Some(LtEq), - "std.and" => Some(And), - "std.or" => Some(Or), - "std.concat" => Some(StringConcat), - _ => None, - } + Some(BinaryOperatorEx::Plain(match name { + "std.mul" => Multiply, + "std.add" => Plus, + "std.sub" => Minus, + "std.eq" => Eq, + "std.ne" => NotEq, + "std.gt" => Gt, + "std.lt" => Lt, + "std.gte" => GtEq, + "std.lte" => LtEq, + "std.and" => And, + "std.or" => Or, + "std.concat" => StringConcat, + "std.seq" => return Some(BinaryOperatorEx::SEq), + "std.sne" => return Some(BinaryOperatorEx::SNotEq), + _ => return None, + })) } pub(super) fn translate_literal(l: Literal, ctx: &Context) -> Result { @@ -929,6 +945,11 @@ impl SQLExpression for sql_ast::Expr { sql_ast::Expr::Like { .. } | sql_ast::Expr::ILike { .. } => 7, + // FIXME: Check that there is no operators with higher/lower binding power. + // 8 should work for prql right now, but I'm not sure about future, shouldn't all + // binding power constants be enumerated in one place? + sql_ast::Expr::IsDistinctFrom(_, _) | sql_ast::Expr::IsNotDistinctFrom(_, _) => 8, + sql_ast::Expr::IsNull(_) | sql_ast::Expr::IsNotNull(_) => 5, // all other items types bind stronger (function calls, literals, ...) 
@@ -977,6 +998,30 @@ impl SQLExpression for UnaryOperator { } } +/// In sqlparser crate, `IS [NOT] DISTINCT FROM` are defined +/// as a separate expression kind instead of `BinaryOperator`. +/// FIXME: Maybe it needs to be refactored upstream? +enum BinaryOperatorEx { + Plain(BinaryOperator), + SEq, + SNotEq, +} +impl SQLExpression for BinaryOperatorEx { + fn binding_strength(&self) -> i32 { + match self { + BinaryOperatorEx::Plain(plain) => plain.binding_strength(), + BinaryOperatorEx::SEq | BinaryOperatorEx::SNotEq => 8, + } + } + + fn associativity(&self) -> Associativity { + match self { + BinaryOperatorEx::Plain(plain) => plain.associativity(), + BinaryOperatorEx::SEq | BinaryOperatorEx::SNotEq => Associativity::Both, + } + } +} + /// A wrapper around sql_ast::Expr, that may have already been converted to source. #[derive(Debug, Clone)] pub enum ExprOrSource { diff --git a/prqlc/prqlc/src/sql/pq/preprocess.rs b/prqlc/prqlc/src/sql/pq/preprocess.rs index a4b98948bf72..7a6b4ba79c9a 100644 --- a/prqlc/prqlc/src/sql/pq/preprocess.rs +++ b/prqlc/prqlc/src/sql/pq/preprocess.rs @@ -507,6 +507,8 @@ fn all_null(exprs: Vec<&Expr>) -> bool { /// Converts `(a == b) and ((c == d) and (e == f))` /// into `([a, c, e], [b, d, f])` +// FIXME: Should similar conversion happen for `===`? It is used for joins, and I don't know about `===` applicability +// here. fn collect_equals(expr: &Expr) -> Result<(Vec<&Expr>, Vec<&Expr>)> { let mut lefts = Vec::new(); let mut rights = Vec::new(); @@ -594,7 +596,9 @@ impl RqFold for Normalizer { }; if let ExprKind::Operator { name, args } = &expr.kind { - if name == "std.eq" && args.len() == 2 { + // Reorder arguments, null === a => a === null + // FIXME: Shouldn't it also process !==?
+ if name == "std.seq" && args.len() == 2 { let (left, right) = (&args[0], &args[1]); let span = expr.span; let new_args = if let ExprKind::Literal(Literal::Null) = &left.kind { diff --git a/prqlc/prqlc/src/sql/std.sql.prql b/prqlc/prqlc/src/sql/std.sql.prql index e2bfada1bfca..f33372146eeb 100644 --- a/prqlc/prqlc/src/sql/std.sql.prql +++ b/prqlc/prqlc/src/sql/std.sql.prql @@ -154,6 +154,12 @@ let eq = l r -> null @{binding_strength=6} let ne = l r -> null +@{binding_strength=6} +let seq = l r -> null + +@{binding_strength=6} +let sne = l r -> null + @{binding_strength=6} let gt = l r -> null diff --git a/prqlc/prqlc/tests/integration/sql.rs b/prqlc/prqlc/tests/integration/sql.rs index 3821803fe638..b3fa5e98f1cb 100644 --- a/prqlc/prqlc/tests/integration/sql.rs +++ b/prqlc/prqlc/tests/integration/sql.rs @@ -357,10 +357,10 @@ fn test_precedence_04() { zero = !gtz && !ltz, is_not_equal = !(a==b), is_not_gt = !(a>b), - negated_is_null_1 = !a == null, - negated_is_null_2 = (!a) == null, - is_not_null = !(a == null), - (a + b) == null, + negated_is_null_1 = !a === null, + negated_is_null_2 = (!a) === null, + is_not_null = !(a === null), + (a + b) === null, } "###).unwrap()), @r###" SELECT @@ -1824,7 +1824,7 @@ fn test_nulls_03() { // IS NULL assert_snapshot!((compile(r###" from employees - filter first_name == null && null == last_name + filter first_name === null && null === last_name "###).unwrap()), @r###" SELECT * @@ -1841,7 +1841,7 @@ fn test_nulls_04() { // IS NOT NULL assert_snapshot!((compile(r###" from employees - filter first_name != null && null != last_name + filter first_name !== null && null !== last_name "###).unwrap()), @r###" SELECT * diff --git a/web/book/src/reference/spec/null.md b/web/book/src/reference/spec/null.md index cee036de1440..27326ea87d82 100644 --- a/web/book/src/reference/spec/null.md +++ b/web/book/src/reference/spec/null.md @@ -1,5 +1,7 @@ # Null handling +FIXME: Update for null-safe equality operator + SQL has an unconventional way 
of handling `NULL` values, since it treats them as unknown values. As a result, in SQL: @@ -17,14 +19,14 @@ For more information, check out the PRQL, on the other hand, treats `null` as a value, which means that: -- `null == null` evaluates to `true`, -- `null != null` evaluates to `false`, +- `null === null` evaluates to `true`, +- `null !== null` evaluates to `false`, - distinct column cannot contain multiple `null` values. ```prql from employees -filter first_name == null -filter null != last_name +filter first_name === null +filter null !== last_name ``` Note that PRQL doesn't change how `NULL` is compared between columns, for diff --git a/web/book/src/reference/syntax/README.md b/web/book/src/reference/syntax/README.md index dd3e332b39c8..8fc5b0c254dc 100644 --- a/web/book/src/reference/syntax/README.md +++ b/web/book/src/reference/syntax/README.md @@ -15,7 +15,7 @@ A summary of PRQL syntax: | `:` | [Named args & parameters](../declarations/functions.md) | `interp low:0 1600 sat_score` | | `{}` | [Tuples](./tuples.md) | `{id, false, total = 3}` | | `[]` | [Arrays](./arrays.md) | `[1, 4, 3, 4]` | -| `+`,`!`,`&&`,`==`, etc | [Operators](./operators.md) | filter a == b + c \|\| d >= e | +| `+`,`!`,`&&`,`==`,`===`, etc | [Operators](./operators.md) | filter a == b + c \|\| d >= e | | `()` | [Parentheses](./operators.md#parentheses) | `derive celsius = (fht - 32) / 1.8` | | `\` | [Line wrap](./operators.md#wrapping-lines) | 1 + 2 + 3 +
\ 4 + 5 | | `1`,`100_000`,`5e10` | [Numbers](./literals.md#numbers) | `derive { huge = 5e10 * 10_000 }` | diff --git a/web/book/src/reference/syntax/operators.md b/web/book/src/reference/syntax/operators.md index e10a86becbe5..2b97f1dcfe52 100644 --- a/web/book/src/reference/syntax/operators.md +++ b/web/book/src/reference/syntax/operators.md @@ -30,7 +30,7 @@ operations and for function calls (see the discussion below.) | pow | `**` | 4 | right-to-left | | mul | `*` `/` `//` `%` | 5 | left-to-right | | add | `+` `-` | 6 | left-to-right | -| compare | `==` `!=` `<=` `>=` `<` `>` | 7 | left-to-right | +| compare | `==` `!=` `===` `!==` `<=` `>=` `<` `>` | 7 | left-to-right | | coalesce | `??` | 8 | left-to-right | | and | `&&` | 9 | left-to-right | | or | \|\| | 10 | left-to-right | diff --git a/web/playground/src/workbench/prql-syntax.js b/web/playground/src/workbench/prql-syntax.js index d5ff826df044..88befa0b667f 100644 --- a/web/playground/src/workbench/prql-syntax.js +++ b/web/playground/src/workbench/prql-syntax.js @@ -40,6 +40,8 @@ const def = { // "**", "==", "!=", + "===", + "!==", "->", "=>", ">", diff --git a/web/website/data/examples/null-handling.yaml b/web/website/data/examples/null-handling.yaml index 38bab1ea3c44..bc9fc2b2a866 100644 --- a/web/website/data/examples/null-handling.yaml +++ b/web/website/data/examples/null-handling.yaml @@ -1,8 +1,8 @@ label: Null handling prql: | from users - filter last_login != null - filter deleted_at == null + filter last_login !== null + filter deleted_at === null derive channel = channel ?? "unknown" sql: | SELECT