diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 010a8189b..39b65c2ef 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1421,7 +1421,7 @@ impl fmt::Display for AccessExpr { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct LambdaFunction { /// The parameters to the lambda function. - pub params: OneOrManyWithParens, + pub params: OneOrManyWithParens, /// The body of the lambda function. pub body: Box, /// The syntax style used to write the lambda function. @@ -1446,6 +1446,27 @@ impl fmt::Display for LambdaFunction { } } +/// A parameter to a lambda function, optionally with a data type. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct LambdaFunctionParameter { + /// The name of the parameter + pub name: Ident, + /// The optional data type of the parameter + /// [Snowflake Syntax](https://docs.snowflake.com/en/sql-reference/functions/filter#arguments) + pub data_type: Option, +} + +impl fmt::Display for LambdaFunctionParameter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.data_type { + Some(dt) => write!(f, "{} {}", self.name, dt), + None => write!(f, "{}", self.name), + } + } +} + /// The syntax style for a lambda function. 
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Copy)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 3b6fa1c29..bc09f1e61 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -632,6 +632,11 @@ impl Dialect for SnowflakeDialect { fn supports_select_wildcard_rename(&self) -> bool { true } + + /// See + fn supports_lambda_functions(&self) -> bool { + true + } } // Peeks ahead to identify tokens that are expected after diff --git a/src/keywords.rs b/src/keywords.rs index f84f4d213..7db7de850 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -827,6 +827,7 @@ define_keywords!( RECEIVE, RECLUSTER, RECURSIVE, + REDUCE, REF, REFERENCES, REFERENCING, @@ -1051,6 +1052,7 @@ define_keywords!( TRACE, TRAILING, TRANSACTION, + TRANSFORM, TRANSIENT, TRANSLATE, TRANSLATE_REGEX, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 585242a8a..57404a75e 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1603,10 +1603,34 @@ impl<'a> Parser<'a> { value: self.parse_introduced_string_expr()?.into(), }) } + // An unreserved word (likely an identifier) is followed by an arrow, + // which indicates a lambda function with a single, untyped parameter. + // For example: `a -> a * 2`. Token::Arrow if self.dialect.supports_lambda_functions() => { self.expect_token(&Token::Arrow)?; Ok(Expr::Lambda(LambdaFunction { - params: OneOrManyWithParens::One(w.to_ident(w_span)), + params: OneOrManyWithParens::One(LambdaFunctionParameter { + name: w.to_ident(w_span), + data_type: None, + }), + body: Box::new(self.parse_expr()?), + syntax: LambdaSyntax::Arrow, + })) + } + // An unreserved word (likely an identifier) that is followed by another word (likely a data type) + // which is then followed by an arrow, which indicates a lambda function with a single, typed parameter. + // For example: `a INT -> a * 2`. 
+ Token::Word(_) + if self.peek_nth_token_ref(1).token == Token::Arrow + && self.dialect.supports_lambda_functions() => + { + let data_type = self.parse_data_type()?; + self.expect_token(&Token::Arrow)?; + Ok(Expr::Lambda(LambdaFunction { + params: OneOrManyWithParens::One(LambdaFunctionParameter { + name: w.to_ident(w_span), + data_type: Some(data_type), + }), body: Box::new(self.parse_expr()?), syntax: LambdaSyntax::Arrow, })) @@ -2192,7 +2216,7 @@ impl<'a> Parser<'a> { return Ok(None); } self.maybe_parse(|p| { - let params = p.parse_comma_separated(|p| p.parse_identifier())?; + let params = p.parse_comma_separated(|p| p.parse_lambda_function_parameter())?; p.expect_token(&Token::RParen)?; p.expect_token(&Token::Arrow)?; let expr = p.parse_expr()?; @@ -2204,7 +2228,7 @@ impl<'a> Parser<'a> { }) } - /// Parses a lambda expression using the `LAMBDA` keyword syntax. + /// Parses a lambda expression following the `LAMBDA` keyword syntax. /// /// Syntax: `LAMBDA : ` /// @@ -2214,30 +2238,49 @@ impl<'a> Parser<'a> { /// /// See fn parse_lambda_expr(&mut self) -> Result { + // Parse the parameters: a single parameter or a comma-separated list, each optionally typed + let params = self.parse_lambda_function_parameters()?; + // Expect the colon separator + self.expect_token(&Token::Colon)?; + // Parse the body expression + let body = self.parse_expr()?; + Ok(Expr::Lambda(LambdaFunction { + params, + body: Box::new(body), + syntax: LambdaSyntax::LambdaKeyword, + })) + } + + /// Parses the parameters of a lambda function with optional typing. 
+ fn parse_lambda_function_parameters( + &mut self, + ) -> Result, ParserError> { // Parse the parameters: either a single identifier or comma-separated identifiers let params = if self.consume_token(&Token::LParen) { // Parenthesized parameters: (x, y) - let params = self.parse_comma_separated(|p| p.parse_identifier())?; + let params = self.parse_comma_separated(|p| p.parse_lambda_function_parameter())?; self.expect_token(&Token::RParen)?; OneOrManyWithParens::Many(params) } else { // Unparenthesized parameters: x or x, y - let params = self.parse_comma_separated(|p| p.parse_identifier())?; + let params = self.parse_comma_separated(|p| p.parse_lambda_function_parameter())?; if params.len() == 1 { OneOrManyWithParens::One(params.into_iter().next().unwrap()) } else { OneOrManyWithParens::Many(params) } }; - // Expect the colon separator - self.expect_token(&Token::Colon)?; - // Parse the body expression - let body = self.parse_expr()?; - Ok(Expr::Lambda(LambdaFunction { - params, - body: Box::new(body), - syntax: LambdaSyntax::LambdaKeyword, - })) + Ok(params) + } + + /// Parses a single parameter of a lambda function, with optional typing. 
+ fn parse_lambda_function_parameter(&mut self) -> Result { + let name = self.parse_identifier()?; + let data_type = match self.peek_token().token { + Token::Word(_) => self.maybe_parse(|p| p.parse_data_type())?, + _ => None, + }; + Ok(LambdaFunctionParameter { name, data_type }) } /// Tries to parse the body of an [ODBC escaping sequence] diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index b6b867049..aafbfb756 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -15872,7 +15872,16 @@ fn test_lambdas() { ] ), Expr::Lambda(LambdaFunction { - params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]), + params: OneOrManyWithParens::Many(vec![ + LambdaFunctionParameter { + name: Ident::new("p1"), + data_type: None + }, + LambdaFunctionParameter { + name: Ident::new("p2"), + data_type: None + } + ]), body: Box::new(Expr::Case { case_token: AttachedToken::empty(), end_token: AttachedToken::empty(), @@ -15917,6 +15926,12 @@ fn test_lambdas() { "map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))", ); dialects.verified_expr("transform(array(1, 2, 3), x -> x + 1)"); + + // Ensure all lambda variants are parsed correctly + dialects.verified_expr("a -> a * 2"); // Single parameter without type + dialects.verified_expr("a INT -> a * 2"); // Single parameter with type + dialects.verified_expr("(a, b) -> a * b"); // Multiple parameters without types + dialects.verified_expr("(a INT, b FLOAT) -> a * b"); // Multiple parameters with types } #[test] diff --git a/tests/sqlparser_databricks.rs b/tests/sqlparser_databricks.rs index b088afd78..899592c98 100644 --- a/tests/sqlparser_databricks.rs +++ b/tests/sqlparser_databricks.rs @@ -72,7 +72,10 @@ fn test_databricks_exists() { ] ), Expr::Lambda(LambdaFunction { - params: OneOrManyWithParens::One(Ident::new("x")), + params: OneOrManyWithParens::One(LambdaFunctionParameter { + name: Ident::new("x"), + data_type: None + }), body: 
Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x"))))), syntax: LambdaSyntax::Arrow, }) @@ -109,7 +112,16 @@ fn test_databricks_lambdas() { ] ), Expr::Lambda(LambdaFunction { - params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]), + params: OneOrManyWithParens::Many(vec![ + LambdaFunctionParameter { + name: Ident::new("p1"), + data_type: None + }, + LambdaFunctionParameter { + name: Ident::new("p2"), + data_type: None + } + ]), body: Box::new(Expr::Case { case_token: AttachedToken::empty(), end_token: AttachedToken::empty(),