11package io .codemodder .remediation .regexdos ;
22
3- import com .github .javaparser .StaticJavaParser ;
43import com .github .javaparser .ast .CompilationUnit ;
54import com .github .javaparser .ast .Node ;
65import com .github .javaparser .ast .NodeList ;
7- import com .github .javaparser .ast .body .ClassOrInterfaceDeclaration ;
8- import com .github .javaparser .ast .body .MethodDeclaration ;
9- import com .github .javaparser .ast .expr .Expression ;
10- import com .github .javaparser .ast .expr .IntegerLiteralExpr ;
11- import com .github .javaparser .ast .expr .LambdaExpr ;
12- import com .github .javaparser .ast .expr .MethodCallExpr ;
6+ import com .github .javaparser .ast .expr .*;
7+ import io .codemodder .DependencyGAV ;
138import io .codemodder .ast .ASTTransforms ;
149import io .codemodder .ast .ASTs ;
1510import io .codemodder .ast .LocalDeclaration ;
1813import io .codemodder .remediation .SuccessOrReason ;
1914import java .util .List ;
2015import java .util .Optional ;
21- import java .util .concurrent .Callable ;
22- import java .util .concurrent .Executors ;
2316
2417/** Adds a timeout function and wraps regex match call with it * */
2518final class RegexDoSFixStrategy extends MatchAndFixStrategy {
@@ -51,33 +44,6 @@ public boolean match(final Node node) {
5144 .isPresent ();
5245 }
5346
54- private static void addTimeoutMethodIfMissing (
55- final CompilationUnit cu , final ClassOrInterfaceDeclaration classDecl ) {
56- final String method =
57- """
58- public <E> E executeWithTimeout(final Callable<E> action, final int timeout){
59- Future<E> maybeResult = Executors.newSingleThreadExecutor().submit(action);
60- try{
61- return maybeResult.get(timeout, TimeUnit.MILLISECONDS);
62- }catch(Exception e){
63- throw new RuntimeException("Failed to execute within time limit.");
64- }
65- }
66- """ ;
67- boolean filterMethodPresent =
68- classDecl .findAll (MethodDeclaration .class ).stream ()
69- .anyMatch (
70- md ->
71- md .getNameAsString ().equals ("executeWithTimeout" )
72- && md .getParameters ().size () == 2 );
73- if (!filterMethodPresent ) {
74- classDecl .addMember (StaticJavaParser .parseMethodDeclaration (method ));
75- }
76- // Add needed import
77- ASTTransforms .addImportIfMissing (cu , Callable .class .getName ());
78- ASTTransforms .addImportIfMissing (cu , Executors .class .getName ());
79- }
80-
8147 @ Override
8248 public SuccessOrReason fix (final CompilationUnit cu , final Node node ) {
8349 // indirect case, assigned to a variable
@@ -91,21 +57,20 @@ public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
9157 if (allValidMethodCalls .isEmpty ()) {
9258 return SuccessOrReason .reason ("Couldn't find any matching methods" );
9359 }
94- // Add executeWithTimout method to the encompassing class
95- var classDecl = call .findAncestor (ClassOrInterfaceDeclaration .class );
96- if (classDecl .isEmpty ()) {
97- return SuccessOrReason .reason ("Couldn't find encompassing class" );
98- }
99- classDecl .ifPresent (cd -> addTimeoutMethodIfMissing (cu , cd ));
60+
10061 for (var mce : allValidMethodCalls ) {
10162 // Wrap it with executeWithTimeout with a default 5000 of timeout
10263 var newCall =
10364 new MethodCallExpr (
65+ new NameExpr ("ExecuteWithTimeout" ),
10466 "executeWithTimeout" ,
105- new LambdaExpr (new NodeList <>(), mce .clone ()),
106- new IntegerLiteralExpr (DEFAULT_TIMEOUT ));
67+ new NodeList <>(
68+ new LambdaExpr (new NodeList <>(), mce .clone ()),
69+ new IntegerLiteralExpr (DEFAULT_TIMEOUT )));
10770 mce .replace (newCall );
10871 }
109- return SuccessOrReason .success ();
72+
73+ ASTTransforms .addImportIfMissing (cu , "io.github.pixee.security.ExecuteWithTimeout" );
74+ return SuccessOrReason .success (List .of (DependencyGAV .JAVA_SECURITY_TOOLKIT ));
11075 }
11176}
0 commit comments