22
22
#include < unordered_set>
23
23
24
24
#include " velox/common/base/Exceptions.h"
25
+ #include " velox/common/fuzzer/ConstrainedGenerators.h"
25
26
#include " velox/exec/fuzzer/FuzzerUtil.h"
26
27
#include " velox/expression/Expr.h"
27
28
#include " velox/expression/FunctionSignature.h"
@@ -272,11 +273,14 @@ ExpressionFuzzer::ExpressionFuzzer(
272
273
const std::shared_ptr<VectorFuzzer>& vectorFuzzer,
273
274
const std::optional<ExpressionFuzzer::Options>& options,
274
275
const std::unordered_map<std::string, std::shared_ptr<ArgGenerator>>&
275
- argGenerators)
276
+ argGenerators,
277
+ const std::unordered_map<std::string, std::shared_ptr<ArgValuesGenerator>>&
278
+ argsOverrideFuncs)
276
279
: options_(options.value_or(Options())),
277
280
vectorFuzzer_ (vectorFuzzer),
278
- state{rng_, std::max (1 , options_.maxLevelOfNesting )},
279
- argGenerators_ (argGenerators) {
281
+ state_{rng_, std::max (1 , options_.maxLevelOfNesting )},
282
+ argGenerators_ (argGenerators),
283
+ funcArgOverrides_{argsOverrideFuncs} {
280
284
VELOX_CHECK (vectorFuzzer, " Vector fuzzer must be provided" );
281
285
seed (initialSeed);
282
286
@@ -432,10 +436,6 @@ ExpressionFuzzer::ExpressionFuzzer(
432
436
addToTypeToExpressionListByTicketTimes (" row" , " row_constructor" );
433
437
addToTypeToExpressionListByTicketTimes (kTypeParameterName , " dereference" );
434
438
}
435
-
436
- // Register function override (for cases where we want to restrict the types
437
- // or parameters we pass to functions).
438
- registerFuncOverride (&ExpressionFuzzer::generateSwitchArgs, " switch" );
439
439
}
440
440
441
441
bool ExpressionFuzzer::isSupportedSignature (
@@ -519,13 +519,6 @@ void ExpressionFuzzer::addToTypeToExpressionListByTicketTimes(
519
519
}
520
520
}
521
521
522
- template <typename TFunc>
523
- void ExpressionFuzzer::registerFuncOverride (
524
- TFunc func,
525
- const std::string& name) {
526
- funcArgOverrides_[name] = std::bind (func, this , std::placeholders::_1);
527
- }
528
-
529
522
void ExpressionFuzzer::seed (size_t seed) {
530
523
rng_.seed (seed);
531
524
vectorFuzzer_->reSeed (seed);
@@ -548,22 +541,23 @@ core::TypedExprPtr ExpressionFuzzer::generateArgConstant(const TypePtr& arg) {
548
541
// columns of the same type exist then there is a 30% chance that it will
549
542
// re-use one of them.
550
543
core::TypedExprPtr ExpressionFuzzer::generateArgColumn (const TypePtr& arg) {
551
- auto & listOfCandidateCols = state .typeToColumnNames_ [arg->toString ()];
544
+ auto & listOfCandidateCols = state_ .typeToColumnNames_ [arg->toString ()];
552
545
bool reuseColumn = options_.enableColumnReuse &&
553
546
!listOfCandidateCols.empty () && vectorFuzzer_->coinToss (0.3 );
554
547
555
548
if (!reuseColumn && options_.maxInputsThreshold .has_value () &&
556
- state .inputRowTypes_ .size () >= options_.maxInputsThreshold .value ()) {
549
+ state_ .inputRowTypes_ .size () >= options_.maxInputsThreshold .value ()) {
557
550
reuseColumn = !listOfCandidateCols.empty ();
558
551
}
559
552
560
553
if (!reuseColumn) {
561
- state.inputRowTypes_ .emplace_back (arg);
562
- state.inputRowNames_ .emplace_back (
563
- fmt::format (" c{}" , state.inputRowTypes_ .size () - 1 ));
564
- listOfCandidateCols.push_back (state.inputRowNames_ .back ());
554
+ state_.inputRowTypes_ .emplace_back (arg);
555
+ state_.inputRowNames_ .emplace_back (
556
+ fmt::format (" c{}" , state_.inputRowTypes_ .size () - 1 ));
557
+ state_.customInputGenerators_ .emplace_back (nullptr );
558
+ listOfCandidateCols.push_back (state_.inputRowNames_ .back ());
565
559
return std::make_shared<core::FieldAccessTypedExpr>(
566
- arg, state .inputRowNames_ .back ());
560
+ arg, state_ .inputRowNames_ .back ());
567
561
}
568
562
size_t chosenColIndex = rand32 (0 , listOfCandidateCols.size () - 1 );
569
563
return std::make_shared<core::FieldAccessTypedExpr>(
@@ -582,7 +576,7 @@ core::TypedExprPtr ExpressionFuzzer::generateArg(const TypePtr& arg) {
582
576
// - Lambdas
583
577
// - Try
584
578
if (argClass >= kArgExpression ) {
585
- if (state .remainingLevelOfNesting_ > 0 ) {
579
+ if (state_ .remainingLevelOfNesting_ > 0 ) {
586
580
return generateExpression (arg);
587
581
}
588
582
argClass = rand32 (0 , 1 );
@@ -732,18 +726,19 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::generateSwitchArgs(
732
726
733
727
ExpressionFuzzer::FuzzedExpressionData ExpressionFuzzer::fuzzExpressions (
734
728
const RowTypePtr& outType) {
735
- state .reset ();
729
+ state_ .reset ();
736
730
VELOX_CHECK_EQ (
737
- state .remainingLevelOfNesting_ , std::max (1 , options_.maxLevelOfNesting ));
731
+ state_ .remainingLevelOfNesting_ , std::max (1 , options_.maxLevelOfNesting ));
738
732
739
733
std::vector<core::TypedExprPtr> expressions;
740
734
for (int i = 0 ; i < outType->size (); i++) {
741
735
expressions.push_back (generateExpression (outType->childAt (i)));
742
736
}
743
737
return {
744
738
std::move (expressions),
745
- ROW (std::move (state.inputRowNames_ ), std::move (state.inputRowTypes_ )),
746
- std::move (state.expressionStats_ )};
739
+ ROW (std::move (state_.inputRowNames_ ), std::move (state_.inputRowTypes_ )),
740
+ std::move (state_.customInputGenerators_ ),
741
+ std::move (state_.expressionStats_ )};
747
742
}
748
743
749
744
ExpressionFuzzer::FuzzedExpressionData ExpressionFuzzer::fuzzExpressions (
@@ -760,16 +755,16 @@ ExpressionFuzzer::FuzzedExpressionData ExpressionFuzzer::fuzzExpression() {
760
755
// chance that it will re-use one of them.
761
756
core::TypedExprPtr ExpressionFuzzer::generateExpression (
762
757
const TypePtr& returnType) {
763
- VELOX_CHECK_GT (state .remainingLevelOfNesting_ , 0 );
764
- --state .remainingLevelOfNesting_ ;
765
- auto guard = folly::makeGuard ([&] { ++state .remainingLevelOfNesting_ ; });
758
+ VELOX_CHECK_GT (state_ .remainingLevelOfNesting_ , 0 );
759
+ --state_ .remainingLevelOfNesting_ ;
760
+ auto guard = folly::makeGuard ([&] { ++state_ .remainingLevelOfNesting_ ; });
766
761
767
762
core::TypedExprPtr expression;
768
763
bool reuseExpression =
769
764
options_.enableExpressionReuse && vectorFuzzer_->coinToss (0.3 );
770
765
if (reuseExpression) {
771
- expression = state .expressionBank_ .getRandomExpression (
772
- returnType, state .remainingLevelOfNesting_ + 1 );
766
+ expression = state_ .expressionBank_ .getRandomExpression (
767
+ returnType, state_ .remainingLevelOfNesting_ + 1 );
773
768
if (expression) {
774
769
return expression;
775
770
}
@@ -796,11 +791,11 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(
796
791
797
792
auto exprTransformer = options_.exprTransformers .find (chosenFunctionName);
798
793
if (exprTransformer != options_.exprTransformers .end ()) {
799
- state .remainingLevelOfNesting_ -=
794
+ state_ .remainingLevelOfNesting_ -=
800
795
exprTransformer->second ->extraLevelOfNesting ();
801
796
}
802
797
803
- if (state .remainingLevelOfNesting_ >= 0 ) {
798
+ if (state_ .remainingLevelOfNesting_ >= 0 ) {
804
799
if (chosenFunctionName == " cast" ) {
805
800
expression = generateCastExpression (returnType);
806
801
} else if (chosenFunctionName == " row_constructor" ) {
@@ -825,7 +820,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(
825
820
if (expression) {
826
821
expression = exprTransformer->second ->transform (std::move (expression));
827
822
}
828
- state .remainingLevelOfNesting_ +=
823
+ state_ .remainingLevelOfNesting_ +=
829
824
exprTransformer->second ->extraLevelOfNesting ();
830
825
}
831
826
}
@@ -841,17 +836,32 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(
841
836
return generateArgColumn (returnType);
842
837
}
843
838
}
844
- state .expressionBank_ .insert (expression);
839
+ state_ .expressionBank_ .insert (expression);
845
840
return expression;
846
841
}
847
842
848
843
std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable (
849
844
const CallableSignature& callable) {
845
+ // Special case for switch because it has a variable number of arguments not
846
+ // specified in the signature. Other functions' argument override should be
847
+ // specified through funcArgOverrides_.
848
+ if (callable.name == " switch" ) {
849
+ return generateSwitchArgs (callable);
850
+ }
851
+
850
852
auto funcIt = funcArgOverrides_.find (callable.name );
851
853
if (funcIt == funcArgOverrides_.end ()) {
852
854
return generateArgs (callable);
853
855
}
854
- return funcIt->second (callable);
856
+ auto args = funcIt->second ->generate (
857
+ callable, vectorFuzzer_->getOptions (), rng_, state_);
858
+ for (auto i = 0 ; i < args.size (); ++i) {
859
+ // Generate arguments not specified in the override.
860
+ if (args[i] == nullptr ) {
861
+ args[i] = generateArg (callable.args .at (i), callable.constantArgs .at (i));
862
+ }
863
+ }
864
+ return args;
855
865
}
856
866
857
867
core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable (
@@ -1124,45 +1134,6 @@ core::TypedExprPtr ExpressionFuzzer::generateDereferenceExpression(
1124
1134
inputExpressions[0 ],
1125
1135
fmt::format (" row_field{}" , referencedIndex));
1126
1136
}
1127
- void ExpressionFuzzer::ExprBank::insert (const core::TypedExprPtr& expression) {
1128
- auto typeString = expression->type ()->toString ();
1129
- if (typeToExprsByLevel_.find (typeString) == typeToExprsByLevel_.end ()) {
1130
- typeToExprsByLevel_.insert (
1131
- {typeString, ExprsIndexedByLevel (maxLevelOfNesting_ + 1 )});
1132
- }
1133
- auto & expressionsByLevel = typeToExprsByLevel_[typeString];
1134
- int nestingLevel = getNestedLevel (expression);
1135
- VELOX_CHECK_LE (nestingLevel, maxLevelOfNesting_);
1136
- expressionsByLevel[nestingLevel].push_back (expression);
1137
- }
1138
-
1139
- core::TypedExprPtr ExpressionFuzzer::ExprBank::getRandomExpression (
1140
- const facebook::velox::TypePtr& returnType,
1141
- int uptoLevelOfNesting) {
1142
- VELOX_CHECK_LE (uptoLevelOfNesting, maxLevelOfNesting_);
1143
- auto typeString = returnType->toString ();
1144
- if (typeToExprsByLevel_.find (typeString) == typeToExprsByLevel_.end ()) {
1145
- return nullptr ;
1146
- }
1147
- auto & expressionsByLevel = typeToExprsByLevel_[typeString];
1148
- int totalToConsider = 0 ;
1149
- for (int i = 0 ; i <= uptoLevelOfNesting; i++) {
1150
- totalToConsider += expressionsByLevel[i].size ();
1151
- }
1152
- if (totalToConsider > 0 ) {
1153
- int choice = boost::random ::uniform_int_distribution<uint32_t >(
1154
- 0 , totalToConsider - 1 )(rng_);
1155
- for (int i = 0 ; i <= uptoLevelOfNesting; i++) {
1156
- if (choice >= expressionsByLevel[i].size ()) {
1157
- choice -= expressionsByLevel[i].size ();
1158
- continue ;
1159
- }
1160
- return expressionsByLevel[i][choice];
1161
- }
1162
- VELOX_CHECK (false , " Should have found an expression." );
1163
- }
1164
- return nullptr ;
1165
- }
1166
1137
1167
1138
TypePtr ExpressionFuzzer::fuzzReturnType () {
1168
1139
auto chooseFromConcreteSignatures = rand32 (0 , 1 );
0 commit comments