Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,7 @@ impl fmt::Display for AccessExpr {
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct LambdaFunction {
/// The parameters to the lambda function.
pub params: OneOrManyWithParens<Ident>,
pub params: OneOrManyWithParens<LambdaFunctionParameter>,
/// The body of the lambda function.
pub body: Box<Expr>,
/// The syntax style used to write the lambda function.
Expand All @@ -1446,6 +1446,27 @@ impl fmt::Display for LambdaFunction {
}
}

/// A parameter to a lambda function, optionally with a data type.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct LambdaFunctionParameter {
/// The name of the parameter
pub name: Ident,
/// The optional data type of the parameter
/// [Snowflake Syntax](https://docs.snowflake.com/en/sql-reference/functions/filter#arguments)
pub data_type: Option<DataType>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a link to the docs containing the data_type syntax?

}

impl fmt::Display for LambdaFunctionParameter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.data_type {
Some(dt) => write!(f, "{} {}", self.name, dt),
None => write!(f, "{}", self.name),
}
}
}

/// The syntax style for a lambda function.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down
5 changes: 5 additions & 0 deletions src/dialect/snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,11 @@ impl Dialect for SnowflakeDialect {
fn supports_select_wildcard_rename(&self) -> bool {
true
}

/// See <https://docs.snowflake.com/en/user-guide/querying-semistructured#label-higher-order-functions>
fn supports_lambda_functions(&self) -> bool {
true
}
}

// Peeks ahead to identify tokens that are expected after
Expand Down
2 changes: 2 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ define_keywords!(
RECEIVE,
RECLUSTER,
RECURSIVE,
REDUCE,
REF,
REFERENCES,
REFERENCING,
Expand Down Expand Up @@ -1051,6 +1052,7 @@ define_keywords!(
TRACE,
TRAILING,
TRANSACTION,
TRANSFORM,
TRANSIENT,
TRANSLATE,
TRANSLATE_REGEX,
Expand Down
71 changes: 57 additions & 14 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1603,10 +1603,34 @@ impl<'a> Parser<'a> {
value: self.parse_introduced_string_expr()?.into(),
})
}
// An unreserved word (likely an identifier) is followed by an arrow,
// which indicates a lambda function with a single, untyped parameter.
// For example: `a -> a * 2`.
Token::Arrow if self.dialect.supports_lambda_functions() => {
self.expect_token(&Token::Arrow)?;
Ok(Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::One(w.to_ident(w_span)),
params: OneOrManyWithParens::One(LambdaFunctionParameter {
name: w.to_ident(w_span),
data_type: None,
}),
body: Box::new(self.parse_expr()?),
syntax: LambdaSyntax::Arrow,
}))
}
// An unreserved word (likely an identifier) that is followed by another word (likley a data type)
// which is then followed by an arrow, which indicates a lambda function with a single, typed parameter.
// For example: `a INT -> a * 2`.
Token::Word(_)
if self.peek_nth_token_ref(1).token == Token::Arrow
&& self.dialect.supports_lambda_functions() =>
{
let data_type = self.parse_data_type()?;
self.expect_token(&Token::Arrow)?;
Ok(Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::One(LambdaFunctionParameter {
name: w.to_ident(w_span),
data_type: Some(data_type),
}),
body: Box::new(self.parse_expr()?),
syntax: LambdaSyntax::Arrow,
}))
Expand Down Expand Up @@ -2192,7 +2216,7 @@ impl<'a> Parser<'a> {
return Ok(None);
}
self.maybe_parse(|p| {
let params = p.parse_comma_separated(|p| p.parse_identifier())?;
let params = p.parse_comma_separated(|p| p.parse_lambda_function_parameter())?;
p.expect_token(&Token::RParen)?;
p.expect_token(&Token::Arrow)?;
let expr = p.parse_expr()?;
Expand All @@ -2204,7 +2228,7 @@ impl<'a> Parser<'a> {
})
}

/// Parses a lambda expression using the `LAMBDA` keyword syntax.
/// Parses a lambda expression following the `LAMBDA` keyword syntax.
///
/// Syntax: `LAMBDA <params> : <expr>`
///
Expand All @@ -2214,30 +2238,49 @@ impl<'a> Parser<'a> {
///
/// See <https://duckdb.org/docs/stable/sql/functions/lambda>
fn parse_lambda_expr(&mut self) -> Result<Expr, ParserError> {
// Parse the parameters: either a single identifier or comma-separated identifiers
let params = self.parse_lambda_function_parameters()?;
// Expect the colon separator
self.expect_token(&Token::Colon)?;
// Parse the body expression
let body = self.parse_expr()?;
Ok(Expr::Lambda(LambdaFunction {
params,
body: Box::new(body),
syntax: LambdaSyntax::LambdaKeyword,
}))
}

/// Parses the parameters of a lambda function with optional typing.
fn parse_lambda_function_parameters(
&mut self,
) -> Result<OneOrManyWithParens<LambdaFunctionParameter>, ParserError> {
// Parse the parameters: either a single identifier or comma-separated identifiers
let params = if self.consume_token(&Token::LParen) {
// Parenthesized parameters: (x, y)
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
let params = self.parse_comma_separated(|p| p.parse_lambda_function_parameter())?;
self.expect_token(&Token::RParen)?;
OneOrManyWithParens::Many(params)
} else {
// Unparenthesized parameters: x or x, y
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
let params = self.parse_comma_separated(|p| p.parse_lambda_function_parameter())?;
if params.len() == 1 {
OneOrManyWithParens::One(params.into_iter().next().unwrap())
} else {
OneOrManyWithParens::Many(params)
}
};
// Expect the colon separator
self.expect_token(&Token::Colon)?;
// Parse the body expression
let body = self.parse_expr()?;
Ok(Expr::Lambda(LambdaFunction {
params,
body: Box::new(body),
syntax: LambdaSyntax::LambdaKeyword,
}))
Ok(params)
}

/// Parses a single parameter of a lambda function, with optional typing.
fn parse_lambda_function_parameter(&mut self) -> Result<LambdaFunctionParameter, ParserError> {
let name = self.parse_identifier()?;
let data_type = match self.peek_token().token {
Token::Word(_) => self.maybe_parse(|p| p.parse_data_type())?,
_ => None,
};
Ok(LambdaFunctionParameter { name, data_type })
}

/// Tries to parse the body of an [ODBC escaping sequence]
Expand Down
17 changes: 16 additions & 1 deletion tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15872,7 +15872,16 @@ fn test_lambdas() {
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
params: OneOrManyWithParens::Many(vec![
LambdaFunctionParameter {
name: Ident::new("p1"),
data_type: None
},
LambdaFunctionParameter {
name: Ident::new("p2"),
data_type: None
}
]),
body: Box::new(Expr::Case {
case_token: AttachedToken::empty(),
end_token: AttachedToken::empty(),
Expand Down Expand Up @@ -15917,6 +15926,12 @@ fn test_lambdas() {
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
);
dialects.verified_expr("transform(array(1, 2, 3), x -> x + 1)");

// Ensure all lambda variants are parsed correctly
dialects.verified_expr("a -> a * 2"); // Single parameter without type
dialects.verified_expr("a INT -> a * 2"); // Single parameter with type
dialects.verified_expr("(a, b) -> a * b"); // Multiple parameters without types
dialects.verified_expr("(a INT, b FLOAT) -> a * b"); // Multiple parameters with types
}

#[test]
Expand Down
16 changes: 14 additions & 2 deletions tests/sqlparser_databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ fn test_databricks_exists() {
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::One(Ident::new("x")),
params: OneOrManyWithParens::One(LambdaFunctionParameter {
name: Ident::new("x"),
data_type: None
}),
body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x"))))),
syntax: LambdaSyntax::Arrow,
})
Expand Down Expand Up @@ -109,7 +112,16 @@ fn test_databricks_lambdas() {
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
params: OneOrManyWithParens::Many(vec![
LambdaFunctionParameter {
name: Ident::new("p1"),
data_type: None
},
LambdaFunctionParameter {
name: Ident::new("p2"),
data_type: None
}
]),
body: Box::new(Expr::Case {
case_token: AttachedToken::empty(),
end_token: AttachedToken::empty(),
Expand Down