Skip to content

Commit 38dc6b6

Browse files
authored
HIVE-28792: Wrong results when query has function call with char parameter type in case expression (Krisztian Kasa, reviewed by Stamatis Zampetakis)
1 parent 4108ec4 commit 38dc6b6

File tree

20 files changed

+387
-536
lines changed

20 files changed

+387
-536
lines changed

parser/src/java/org/apache/hadoop/hive/ql/parse/IdentifiersParser.g

+32-9
Original file line numberDiff line numberDiff line change
@@ -341,24 +341,47 @@ castExpression
341341
-> ^(TOK_FUNCTION {adaptor.create(Identifier, "cast_format")} NumberLiteral[Integer.toString(((CommonTree)toType.getTree()).token.getType())] expression StringLiteral NumberLiteral[((CommonTree)toType.getTree()).getChild(0).getText()])
342342
;
343343

344-
caseExpression
344+
whenExpression
345345
@init { gParent.pushMsg("case expression", state); }
346346
@after { gParent.popMsg(state); }
347347
:
348-
KW_CASE expression
349-
(KW_WHEN expression KW_THEN expression)+
348+
KW_CASE
349+
( KW_WHEN expression KW_THEN expression)+
350350
(KW_ELSE expression)?
351-
KW_END -> ^(TOK_FUNCTION KW_CASE expression*)
351+
KW_END -> ^(TOK_FUNCTION KW_WHEN expression*)
352352
;
353353

354-
whenExpression
354+
// Make caseExpression build the same tree shape as whenExpression.
355+
// Rewrite
356+
// CASE a
357+
// WHEN b THEN c
358+
// [WHEN d THEN e]* [ELSE f]
359+
// END
360+
// to
361+
// CASE
362+
// WHEN a=b THEN c
363+
// [WHEN a=d THEN e]* [ELSE f]
364+
// END
365+
caseExpression
355366
@init { gParent.pushMsg("case expression", state); }
356367
@after { gParent.popMsg(state); }
357368
:
358-
KW_CASE
359-
( KW_WHEN expression KW_THEN expression)+
360-
(KW_ELSE expression)?
361-
KW_END -> ^(TOK_FUNCTION KW_WHEN expression*)
369+
KW_CASE caseOperand=expression
370+
// Pass the case operand to the rule that parses the WHEN branches
371+
whenBranches[$caseOperand.tree]
372+
(KW_ELSE elseResult=expression)?
373+
KW_END -> ^(TOK_FUNCTION Identifier["when"] whenBranches $elseResult?)
374+
;
375+
376+
whenBranches[CommonTree caseOperand]
377+
:
378+
(whenExpressionBranch[caseOperand] KW_THEN! expression)+
379+
;
380+
381+
whenExpressionBranch[CommonTree caseOperand]
382+
:
383+
KW_WHEN when=expression
384+
-> ^(EQUAL["="] {$caseOperand} $when)
362385
;
363386

364387
floorExpression
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.hadoop.hive.ql.parse;
19+
20+
import org.junit.Assert;
21+
import org.junit.Test;
22+
23+
public class TestParseCase {
24+
ParseDriver parseDriver = new ParseDriver();
25+
26+
@Test
27+
public void testParseCaseWithOperandAndOneBranch() throws Exception {
28+
ASTNode tree = parseDriver.parseSelect("select case upper(col1) when 'A' then 'OK' else 'N/A' end from t1", null).getTree();
29+
30+
String result = "\n" +
31+
"TOK_SELECT\n" +
32+
" TOK_SELEXPR\n" +
33+
" TOK_FUNCTION\n" +
34+
" when\n" +
35+
" =\n" +
36+
" TOK_FUNCTION\n" +
37+
" upper\n" +
38+
" TOK_TABLE_OR_COL\n" +
39+
" col1\n" +
40+
" 'A'\n" +
41+
" 'OK'\n" +
42+
" 'N/A'\n";
43+
44+
Assert.assertEquals(result, tree.dump());
45+
}
46+
47+
@Test
48+
public void testParseCaseWithOperandAndMultipleBranches() throws Exception {
49+
ASTNode tree = parseDriver.parseSelect(
50+
"select case a" +
51+
" when 'B' then 'bean'" +
52+
" when 'A' then 'apple' else 'N/A' end from t1", null).getTree();
53+
54+
String result = "\n" +
55+
"TOK_SELECT\n" +
56+
" TOK_SELEXPR\n" +
57+
" TOK_FUNCTION\n" +
58+
" when\n" +
59+
" =\n" +
60+
" TOK_TABLE_OR_COL\n" +
61+
" a\n" +
62+
" 'B'\n" +
63+
" 'bean'\n" +
64+
" =\n" +
65+
" TOK_TABLE_OR_COL\n" +
66+
" a\n" +
67+
" 'A'\n" +
68+
" 'apple'\n" +
69+
" 'N/A'\n";
70+
71+
Assert.assertEquals(result, tree.dump());
72+
}
73+
74+
@Test
75+
public void testParseCaseWithOperandAndNoElse() throws Exception {
76+
ASTNode tree = parseDriver.parseSelect("select case a when 'A' then 'OK' end from t1", null).getTree();
77+
78+
String result = "\n" +
79+
"TOK_SELECT\n" +
80+
" TOK_SELEXPR\n" +
81+
" TOK_FUNCTION\n" +
82+
" when\n" +
83+
" =\n" +
84+
" TOK_TABLE_OR_COL\n" +
85+
" a\n" +
86+
" 'A'\n" +
87+
" 'OK'\n";
88+
89+
Assert.assertEquals(result, tree.dump());
90+
}
91+
}

ql/src/java/org/apache/hadoop/hive/ql/exec/ExprNodeGenericFuncEvaluator.java

+1-5
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
package org.apache.hadoop.hive.ql.exec;
2020

21-
import com.google.common.base.Preconditions;
22-
2321
import org.slf4j.Logger;
2422
import org.slf4j.LoggerFactory;
2523
import org.apache.hadoop.conf.Configuration;
@@ -30,7 +28,6 @@
3028
import org.apache.hadoop.hive.ql.session.SessionState;
3129
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
3230
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare;
33-
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCase;
3431
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen;
3532
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
3633
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -118,8 +115,7 @@ public ExprNodeGenericFuncEvaluator(ExprNodeGenericFuncDesc expr, Configuration
118115
}
119116
}
120117
genericUDF = expr.getGenericUDF();
121-
if (isEager &&
122-
(genericUDF instanceof GenericUDFCase || genericUDF instanceof GenericUDFWhen)) {
118+
if (isEager && genericUDF instanceof GenericUDFWhen) {
123119
throw new HiveException("Stateful expressions cannot be used inside of CASE");
124120
}
125121
}

ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java

-1
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,6 @@ public final class FunctionRegistry {
590590
system.registerGenericUDF("create_union", GenericUDFUnion.class);
591591
system.registerGenericUDF("extract_union", GenericUDFExtractUnion.class);
592592

593-
system.registerGenericUDF("case", GenericUDFCase.class);
594593
system.registerGenericUDF("when", GenericUDFWhen.class);
595594
system.registerGenericUDF("nullif", GenericUDFNullif.class);
596595
system.registerGenericUDF("hash", GenericUDFHash.class);

ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -1474,17 +1474,16 @@ private static boolean isNonVectorizedPathUDF(ExprNodeGenericFuncDesc expr,
14741474
} else if (gudf instanceof GenericUDFFromUnixTime && isIntFamily(arg0Type(expr))
14751475
|| (gudf instanceof GenericUDFTimestamp && isStringFamily(arg0Type(expr)))
14761476

1477-
/* GenericUDFCase and GenericUDFWhen are implemented with the UDF Adaptor because
1478-
* of their complexity and generality. In the future, variations of these
1477+
/* GenericUDFWhen is implemented with the UDF Adaptor because
1478+
* of its complexity and generality. In the future, variations of this
14791479
* can be optimized to run faster for the vectorized code path. For example,
14801480
* CASE col WHEN 1 then "one" WHEN 2 THEN "two" ELSE "other" END
1481-
* is an example of a GenericUDFCase that has all constant arguments
1481+
* is an example of a CASE expression whose arguments are all constant
14821482
* except for the first argument. This is probably a common case and a
14831483
* good candidate for a fast, special-purpose VectorExpression. Then
14841484
* the UDF Adaptor code path could be used as a catch-all for
14851485
* non-optimized general cases.
14861486
*/
1487-
|| gudf instanceof GenericUDFCase
14881487
|| gudf instanceof GenericUDFWhen) {
14891488
return true;
14901489
} else // between has 4 args here, but can be vectorized like this

ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java

+2-71
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
6868
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare;
6969
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge;
70-
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCase;
7170
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCoalesce;
7271
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
7372
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
@@ -564,13 +563,13 @@ private static ExprNodeDesc shortcutFunction(GenericUDF udf, List<ExprNodeDesc>
564563
ExprNodeGenericFuncDesc caseOrWhenexpr = null;
565564
if (newExprs.get(0) instanceof ExprNodeGenericFuncDesc) {
566565
caseOrWhenexpr = (ExprNodeGenericFuncDesc) newExprs.get(0);
567-
if (caseOrWhenexpr.getGenericUDF() instanceof GenericUDFWhen || caseOrWhenexpr.getGenericUDF() instanceof GenericUDFCase) {
566+
if (caseOrWhenexpr.getGenericUDF() instanceof GenericUDFWhen) {
568567
foundUDFInFirst = true;
569568
}
570569
}
571570
if (!foundUDFInFirst && newExprs.get(1) instanceof ExprNodeGenericFuncDesc) {
572571
caseOrWhenexpr = (ExprNodeGenericFuncDesc) newExprs.get(1);
573-
if (!(caseOrWhenexpr.getGenericUDF() instanceof GenericUDFWhen || caseOrWhenexpr.getGenericUDF() instanceof GenericUDFCase)) {
572+
if (!(caseOrWhenexpr.getGenericUDF() instanceof GenericUDFWhen)) {
574573
return null;
575574
}
576575
}
@@ -596,21 +595,6 @@ private static ExprNodeDesc shortcutFunction(GenericUDF udf, List<ExprNodeDesc>
596595
ExprNodeGenericFuncDesc newCaseOrWhenExpr = ExprNodeGenericFuncDesc.newInstance(childUDF,
597596
caseOrWhenexpr.getFuncText(), children);
598597
return newCaseOrWhenExpr;
599-
} else if (childUDF instanceof GenericUDFCase) {
600-
for (i = 2; i < children.size(); i+=2) {
601-
children.set(i, ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPEqual(),
602-
Lists.newArrayList(children.get(i),newExprs.get(foundUDFInFirst ? 1 : 0))));
603-
}
604-
if(children.size() % 2 == 0) {
605-
i = children.size()-1;
606-
children.set(i, ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPEqual(),
607-
Lists.newArrayList(children.get(i),newExprs.get(foundUDFInFirst ? 1 : 0))));
608-
}
609-
// after constant folding of child expression the return type of UDFCase might have changed,
610-
// so recreate the expression
611-
ExprNodeGenericFuncDesc newCaseOrWhenExpr = ExprNodeGenericFuncDesc.newInstance(childUDF,
612-
caseOrWhenexpr.getFuncText(), children);
613-
return newCaseOrWhenExpr;
614598
} else {
615599
// can't happen
616600
return null;
@@ -769,59 +753,6 @@ private static ExprNodeDesc shortcutFunction(GenericUDF udf, List<ExprNodeDesc>
769753
}
770754
}
771755

772-
if (udf instanceof GenericUDFCase) {
773-
// HIVE-9644 Attempt to fold expression like :
774-
// where (case ss_sold_date when '1998-01-01' then 1=1 else null=1 end);
775-
// where ss_sold_date= '1998-01-01' ;
776-
if (!(newExprs.size() == 3 || newExprs.size() == 4)) {
777-
// In general case can have unlimited # of branches,
778-
// we currently only handle either 1 or 2 branch.
779-
return null;
780-
}
781-
ExprNodeDesc thenExpr = newExprs.get(2);
782-
ExprNodeDesc elseExpr = newExprs.size() == 4 ? newExprs.get(3) :
783-
new ExprNodeConstantDesc(newExprs.get(2).getTypeInfo(),null);
784-
785-
if (thenExpr instanceof ExprNodeConstantDesc && elseExpr instanceof ExprNodeConstantDesc) {
786-
ExprNodeConstantDesc constThen = (ExprNodeConstantDesc) thenExpr;
787-
ExprNodeConstantDesc constElse = (ExprNodeConstantDesc) elseExpr;
788-
Object thenVal = constThen.getValue();
789-
Object elseVal = constElse.getValue();
790-
if (thenVal == null) {
791-
if (null == elseVal) {
792-
return thenExpr;
793-
} else if (op instanceof FilterOperator) {
794-
return Boolean.TRUE.equals(elseVal) ? ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPNotEqual(), newExprs.subList(0, 2)) :
795-
Boolean.FALSE.equals(elseVal) ? elseExpr : null;
796-
} else {
797-
return null;
798-
}
799-
} else if (null == elseVal && op instanceof FilterOperator) {
800-
return Boolean.TRUE.equals(thenVal) ? ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPEqual(), newExprs.subList(0, 2)) :
801-
Boolean.FALSE.equals(thenVal) ? thenExpr : null;
802-
} else if(thenVal.equals(elseVal)){
803-
return thenExpr;
804-
} else if (thenVal instanceof Boolean && elseVal instanceof Boolean) {
805-
ExprNodeGenericFuncDesc equal = ExprNodeGenericFuncDesc.newInstance(
806-
new GenericUDFOPEqual(), newExprs.subList(0, 2));
807-
List<ExprNodeDesc> children = new ArrayList<>();
808-
children.add(equal);
809-
children.add(new ExprNodeConstantDesc(false));
810-
ExprNodeGenericFuncDesc func = ExprNodeGenericFuncDesc.newInstance(new GenericUDFCoalesce(),
811-
children);
812-
if (Boolean.TRUE.equals(thenVal)) {
813-
return func;
814-
} else {
815-
List<ExprNodeDesc> exprs = new ArrayList<>();
816-
exprs.add(func);
817-
return ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPNot(), exprs);
818-
}
819-
} else {
820-
return null;
821-
}
822-
}
823-
}
824-
825756
if (udf instanceof GenericUDFUnixTimeStamp) {
826757
if (newExprs.size() >= 1) {
827758
// unix_timestamp(args) -> to_unix_timestamp(args)

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexExecutorImpl.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public HiveRexExecutorImpl() {
4747

4848
@Override
4949
public void reduce(RexBuilder rexBuilder, List<RexNode> constExps, List<RexNode> reducedValues) {
50-
RexNodeConverter rexNodeConverter = new RexNodeConverter(rexBuilder, rexBuilder.getTypeFactory());
50+
RexNodeConverter rexNodeConverter = new RexNodeConverter(rexBuilder);
5151
for (RexNode rexNode : constExps) {
5252
// initialize the converter
5353
ExprNodeConverter converter = new ExprNodeConverter("", null, null, null,

0 commit comments

Comments
 (0)