Skip to content

Commit 2567265

Browse files
authored
feat: support optional arg name for create function (#18848)
1 parent c2b5429 commit 2567265

File tree

7 files changed: 392 additions and 63 deletions.

src/query/ast/src/ast/statements/udf.rs

Lines changed: 51 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,16 @@ pub enum UDFArgs {
3333
NameWithTypes(Vec<(Identifier, TypeName)>),
3434
}
3535

36+
/// Parameter list of a lambda UDF, as stored in the `parameters` field of
/// `UDFDefinition::LambdaUDF`.
///
/// A lambda UDF declares its parameters either as bare names or as names
/// paired with a declared type.
#[derive(Debug, Clone, PartialEq, Drive, DriveMut)]
37+
pub enum LambdaUDFParams {
38+
/// Parameter names only, e.g. `AS (p) -> ...`.
Names(Vec<Identifier>),
39+
/// Each parameter name together with its type, e.g. `AS (p INT) -> ...`.
NameWithTypes(Vec<(Identifier, TypeName)>),
40+
}
41+
3642
#[derive(Debug, Clone, PartialEq, Drive, DriveMut)]
3743
pub enum UDFDefinition {
3844
LambdaUDF {
39-
parameters: Vec<Identifier>,
45+
parameters: LambdaUDFParams,
4046
definition: Box<Expr>,
4147
},
4248
UDFServer {
@@ -49,7 +55,7 @@ pub enum UDFDefinition {
4955
immutable: Option<bool>,
5056
},
5157
UDFScript {
52-
arg_types: Vec<TypeName>,
58+
arg_types: UDFArgs,
5359
return_type: TypeName,
5460
code: String,
5561
imports: Vec<String>,
@@ -68,7 +74,7 @@ pub enum UDFDefinition {
6874
language: String,
6975
},
7076
UDAFScript {
71-
arg_types: Vec<TypeName>,
77+
arg_types: UDFArgs,
7278
state_fields: Vec<UDAFStateField>,
7379
return_type: TypeName,
7480
imports: Vec<String>,
@@ -89,6 +95,17 @@ pub enum UDFDefinition {
8995
},
9096
}
9197

98+
impl LambdaUDFParams {
99+
/// Returns an iterator over the parameter names, in the order they are
/// stored, regardless of which variant `self` is.
///
/// For `NameWithTypes` the type half of each pair is ignored.
// The two arms yield different concrete iterator types (`slice::Iter` vs.
// a `Map` adaptor), so the result is boxed behind `dyn Iterator`.
pub fn names_iter(&self) -> Box<dyn Iterator<Item = &Identifier> + '_> {
100+
match self {
101+
LambdaUDFParams::Names(names) => Box::new(names.iter()),
102+
LambdaUDFParams::NameWithTypes(name_with_types) => {
103+
Box::new(name_with_types.iter().map(|(name, _)| name))
104+
}
105+
}
106+
}
107+
}
108+
92109
impl UDFArgs {
93110
pub fn len(&self) -> usize {
94111
match self {
@@ -100,6 +117,34 @@ impl UDFArgs {
100117
pub fn is_empty(&self) -> bool {
101118
self.len() == 0
102119
}
120+
121+
/// Returns an iterator over the argument types, in the order they are
/// stored, regardless of which variant `self` is.
///
/// For `NameWithTypes` the name half of each pair is ignored.
// Boxed because the two arms yield different concrete iterator types.
pub fn types_iter(&self) -> Box<dyn Iterator<Item = &TypeName> + '_> {
122+
match self {
123+
UDFArgs::Types(types) => Box::new(types.iter()),
124+
UDFArgs::NameWithTypes(name_with_types) => {
125+
Box::new(name_with_types.iter().map(|(_, ty)| ty))
126+
}
127+
}
128+
}
129+
}
130+
131+
/// Formats the parameter list as a sequence of items passed to
/// `write_comma_separated_list`.
///
/// Bare names are written as-is; typed parameters are written as
/// `<name> <type>`. No surrounding parentheses are emitted here.
// NOTE(review): the exact separator is decided by `write_comma_separated_list`,
// which is defined elsewhere — presumably ", "; confirm against the helper.
impl Display for LambdaUDFParams {
132+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
133+
match self {
134+
LambdaUDFParams::Names(names) => {
135+
write_comma_separated_list(f, names)?;
136+
}
137+
LambdaUDFParams::NameWithTypes(name_with_types) => {
138+
// Each pair is rendered to an owned `String` first so the helper
// receives a single displayable item per parameter.
write_comma_separated_list(
139+
f,
140+
name_with_types
141+
.iter()
142+
.map(|(name, ty)| format!("{name} {ty}")),
143+
)?;
144+
}
145+
}
146+
Ok(())
147+
}
103148
}
104149

105150
impl Display for UDFArgs {
@@ -128,8 +173,7 @@ impl Display for UDFDefinition {
128173
parameters,
129174
definition,
130175
} => {
131-
write!(f, "AS (")?;
132-
write_comma_separated_list(f, parameters)?;
176+
write!(f, "AS ({parameters}")?;
133177
write!(f, ") -> {definition}")?;
134178
}
135179
UDFDefinition::UDFServer {
@@ -174,8 +218,7 @@ impl Display for UDFDefinition {
174218
packages,
175219
immutable,
176220
} => {
177-
write!(f, "( ")?;
178-
write_comma_separated_list(f, arg_types)?;
221+
write!(f, "( {arg_types}")?;
179222
let imports = imports
180223
.iter()
181224
.map(|s| QuotedString(s, '\'').to_string())
@@ -270,8 +313,7 @@ impl Display for UDFDefinition {
270313
.map(|s| QuotedString(s, '\'').to_string())
271314
.join(",");
272315

273-
write!(f, "( ")?;
274-
write_comma_separated_list(f, arg_types)?;
316+
write!(f, "( {arg_types}")?;
275317
write!(f, " ) STATE {{ ")?;
276318
write_comma_separated_list(f, state_types)?;
277319
write!(

src/query/ast/src/parser/statement.rs

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5317,10 +5317,10 @@ pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
53175317

53185318
let lambda_udf = map(
53195319
rule! {
5320-
AS ~ "(" ~ #comma_separated_list0(ident) ~ ")"
5320+
AS ~ #lambda_udf_params
53215321
~ "->" ~ #expr
53225322
},
5323-
|(_, _, parameters, _, _, definition)| UDFDefinition::LambdaUDF {
5323+
|(_, parameters, _, definition)| UDFDefinition::LambdaUDF {
53245324
parameters,
53255325
definition: Box::new(definition),
53265326
},
@@ -5354,11 +5354,6 @@ pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
53545354
address_or_code,
53555355
)| {
53565356
if address_or_code.1 {
5357-
let UDFArgs::Types(arg_types) = arg_types else {
5358-
return Err(nom::Err::Failure(ErrorKind::Other(
5359-
"UDFScript parameters can only be of type",
5360-
)));
5361-
};
53625357
Ok(UDFDefinition::UDFScript {
53635358
arg_types,
53645359
return_type,
@@ -5439,11 +5434,6 @@ pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
54395434
address_or_code,
54405435
)| {
54415436
if address_or_code.1 {
5442-
let UDFArgs::Types(arg_types) = arg_types else {
5443-
return Err(nom::Err::Failure(ErrorKind::Other(
5444-
"UDAFScript parameters can only be of type",
5445-
)));
5446-
};
54475437
Ok(UDFDefinition::UDAFScript {
54485438
arg_types,
54495439
state_fields: state_types,
@@ -5476,10 +5466,30 @@ pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
54765466
);
54775467

54785468
rule!(
5479-
#lambda_udf: "AS (<parameter>, ...) -> <definition expr>"
5480-
| #udaf: "(<arg_type>, ...) STATE {<state_field>, ...} RETURNS <return_type> LANGUAGE <language> { ADDRESS=<udf_server_address> | AS <language_codes> } "
5481-
| #udf: "(<arg_type>, ...) RETURNS <return_type> LANGUAGE <language> HANDLER=<handler> { ADDRESS=<udf_server_address> | AS <language_codes> } "
5482-
| #scalar_udf_or_udtf: "(<arg_type>, ...) RETURNS <return body> AS <sql> }"
5469+
#lambda_udf: "AS (<parameter [parameter type]>, ...) -> <definition expr>"
5470+
| #udaf: "(<[arg_name] arg_type>, ...) STATE {<state_field>, ...} RETURNS <return_type> LANGUAGE <language> { ADDRESS=<udf_server_address> | AS <language_codes> } "
5471+
| #udf: "(<[arg_name] arg_type>, ...) RETURNS <return_type> LANGUAGE <language> HANDLER=<handler> { ADDRESS=<udf_server_address> | AS <language_codes> } "
5472+
| #scalar_udf_or_udtf: "(<arg_name arg_type>, ...) RETURNS <return body> AS <sql> }"
5473+
)(i)
5474+
}
5475+
5476+
/// Parses the parenthesized parameter list of a lambda UDF into
/// [`LambdaUDFParams`].
///
/// Two alternatives are defined, both delimited by `(` and `)`:
/// - a comma-separated list of `ident` items, producing `LambdaUDFParams::Names`;
/// - a comma-separated list of `udtf_arg` items, producing
///   `LambdaUDFParams::NameWithTypes`.
// NOTE(review): `udtf_arg` is defined elsewhere; from the closure's result
// type it presumably yields `(Identifier, TypeName)` — confirm.
// NOTE(review): the final `rule!` lists `names` before `name_with_types`;
// presumably alternatives are tried in that order, so an empty `()` would
// resolve to `Names(vec![])` — confirm against the `rule!` macro semantics.
fn lambda_udf_params(i: Input) -> IResult<LambdaUDFParams> {
5477+
// Alternative 1: names only, e.g. `(a, b)`.
let names = map(
5478+
rule! {
5479+
"(" ~ #comma_separated_list0(ident) ~ ")"
5480+
},
5481+
|(_, names, _)| LambdaUDFParams::Names(names),
5482+
);
5483+
// Alternative 2: name/type pairs, e.g. `(a INT, b STRING)`.
let name_with_types = map(
5484+
rule! {
5485+
"(" ~ #comma_separated_list0(udtf_arg) ~ ")"
5486+
},
5487+
|(_, name_with_types, _)| LambdaUDFParams::NameWithTypes(name_with_types),
5488+
);
5489+
5490+
// The string after each alternative is its label in parser error context.
rule!(
5491+
#names: "(<arg_name>, ...)"
5492+
| #name_with_types: "(<arg_name arg_type>, ...)"
54835493
)(i)
54845494
}
54855495

src/query/ast/tests/it/parser.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,7 @@ SELECT * from s;"#,
838838
r#"GRANT OWNERSHIP ON UDF f1 TO ROLE 'd20_0015_owner';"#,
839839
r#"attach table t 's3://a' connection=(access_key_id ='x' secret_access_key ='y' endpoint_url='http://127.0.0.1:9900')"#,
840840
r#"CREATE FUNCTION IF NOT EXISTS isnotempty AS(p) -> not(is_null(p));"#,
841+
r#"CREATE FUNCTION IF NOT EXISTS isnotempty AS(p INT) -> not(is_null(p));"#,
841842
r#"CREATE OR REPLACE FUNCTION isnotempty_test_replace AS(p) -> not(is_null(p)) DESC = 'This is a description';"#,
842843
r#"CREATE OR REPLACE FUNCTION isnotempty_test_replace (p STRING) RETURNS BOOL AS $$ not(is_null(p)) $$;"#,
843844
r#"CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';"#,
@@ -874,6 +875,8 @@ SELECT * from s;"#,
874875
r#"DROP FUNCTION isnotempty;"#,
875876
r#"CREATE FUNCTION IF NOT EXISTS my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815';"#,
876877
r#"CREATE FUNCTION IF NOT EXISTS my_agg (INT) STATE { s STRING, i INT NOT NULL } RETURNS BOOLEAN LANGUAGE javascript AS 'some code';"#,
878+
r#"CREATE FUNCTION IF NOT EXISTS my_agg (a INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815';"#,
879+
r#"CREATE FUNCTION IF NOT EXISTS my_agg (a INT) STATE { s STRING, i INT NOT NULL } RETURNS BOOLEAN LANGUAGE javascript AS 'some code';"#,
877880
r#"ALTER FUNCTION my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript AS 'some code';"#,
878881
r#"
879882
EXECUTE IMMEDIATE

src/query/ast/tests/it/testdata/stmt-error.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,7 @@ error:
10431043
| | | | |
10441044
| | | | while parsing `(<expr> [, ...])`
10451045
| | | while parsing expression
1046-
| | while parsing AS (<parameter>, ...) -> <definition expr>
1046+
| | while parsing AS (<parameter [parameter type]>, ...) -> <definition expr>
10471047
| while parsing `CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] <udf_name> <udf_definition> [DESC = <description>]`
10481048

10491049

@@ -1056,7 +1056,7 @@ error:
10561056
1 | CREATE FUNCTION my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript HANDLER = 'my_agg' ADDRESS = 'http://0.0.0.0:8815';
10571057
| ------ - ^^^^^^^ unexpected `HANDLER`, expecting `HEADERS`, `ADDRESS`, `PACKAGES`, `AS`, or `IMPORTS`
10581058
| | |
1059-
| | while parsing (<arg_type>, ...) STATE {<state_field>, ...} RETURNS <return_type> LANGUAGE <language> { ADDRESS=<udf_server_address> | AS <language_codes> }
1059+
| | while parsing (<[arg_name] arg_type>, ...) STATE {<state_field>, ...} RETURNS <return_type> LANGUAGE <language> { ADDRESS=<udf_server_address> | AS <language_codes> }
10601060
| while parsing `CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] <udf_name> <udf_definition> [DESC = <description>]`
10611061

10621062

@@ -1069,7 +1069,7 @@ error:
10691069
1 | CREATE FUNCTION my_agg (INT) STATE { s STRIN } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815';
10701070
| ------ - ^^^^^ unexpected `STRIN`, expecting `STRING`, `SIGNED`, `INTERVAL`, `TINYINT`, `VARIANT`, `SMALLINT`, `TINYBLOB`, `VARBINARY`, `INT8`, `JSON`, `INT16`, `INT32`, `INT64`, `UINT8`, `BIGINT`, `UINT16`, `UINT32`, `UINT64`, `BINARY`, `INTEGER`, `DATETIME`, `NUMERIC`, `TIMESTAMP`, `UNSIGNED`, `STAGE_LOCATION`, `REAL`, `DATE`, `CHAR`, `TEXT`, `ARRAY`, `TUPLE`, `VECTOR`, `BOOLEAN`, `DECIMAL`, `VARCHAR`, `LONGBLOB`, `NULLABLE`, `CHARACTER`, `GEOGRAPHY`, `MEDIUMBLOB`, `BITMAP`, `}`, `BOOL`, `INT`, `FLOAT32`, `FLOAT`, `FLOAT64`, `DOUBLE`, `MAP`, `BLOB`, or `GEOMETRY`
10711071
| | |
1072-
| | while parsing (<arg_type>, ...) STATE {<state_field>, ...} RETURNS <return_type> LANGUAGE <language> { ADDRESS=<udf_server_address> | AS <language_codes> }
1072+
| | while parsing (<[arg_name] arg_type>, ...) STATE {<state_field>, ...} RETURNS <return_type> LANGUAGE <language> { ADDRESS=<udf_server_address> | AS <language_codes> }
10731073
| while parsing `CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] <udf_name> <udf_definition> [DESC = <description>]`
10741074

10751075

0 commit comments

Comments
 (0)