diff --git a/.gitignore b/.gitignore
index 887c9e4..4c0b550 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,7 +8,7 @@ bin
# Ignore auto-generated benchmarking files and dependencies
benchmarking/jars/
benchmarking/datasets/
-benchmarking/outputs/
+benchmarking/results/
benchmarking/sources/
benchmarking/compiled_classes/
diff --git a/benchmarking/benchmarking.md b/benchmarking/benchmarking.md
index 2b67c8a..f796886 100644
--- a/benchmarking/benchmarking.md
+++ b/benchmarking/benchmarking.md
@@ -1,6 +1,6 @@
# Benchmarking
-# Stage Zero: Initiailization
+# Stage Zero
The first stage of benchmarking involves downloading all the necessary files
required, i.e. Jar files and the NJR-1 Dataset.
@@ -10,7 +10,8 @@ required, i.e. Jar files and the NJR-1 Dataset.
In order for benchmarking to run, the jar files for all the following
dependencies must be located in the `JARS_DIR` directory, in their respective
sub-folder (i.e. `JARS_DIR/errorprone`, `JARS_DIR/nullaway`, and
-`JARS_DIR/annotator`)
+`JARS_DIR/annotator`). If not present, they will be automaticlaly downloaded
+from the Maven Repository.
## Note: Different versions of the following dependencies may not be compatable. The newest versions of each project that are confirmed to work together are:
@@ -65,13 +66,3 @@ before NullAway can process them.
- [Checker-Qual](https://mvnrepository.com/artifact/org.checkerframework/checker-qual/)
- checker-qual contains annotations (type qualifiers) that a programmer writes
to specify Java code for type-checking by the Checker Framework.
-
-# Stage One: Annotation
-
-Stage One is the largest stage in terms of time consumption and computation. It
-involves running NullAwayAnnotator on the entire NJR-1 dataset in order to
-prepare it for refactoring, as well as to get an accurate count of the number of
-NullAway errors in the original programs, refactoring every program using VGR,
-and finally re-running annotator to get an updated error count. This cycle
-(annotate->refactor->annotate) is completed for each program in NJR-1
-sequentially.
diff --git a/benchmarking/run_benchmark.py b/benchmarking/run_benchmark.py
index b7c19b2..fd20a3b 100644
--- a/benchmarking/run_benchmark.py
+++ b/benchmarking/run_benchmark.py
@@ -1,20 +1,8 @@
# pyright: basic
-import argparse
-import csv
from datetime import datetime
import os
-import re
import shutil
-import subprocess
import sys
-from typing import TypedDict
-
-
-class BenchmarkingResult(TypedDict):
- benchmark: str
- initial_error_count: int | str # "Error" for failed runs; Error count otherwise
- refactored_error_count: int | str # "Error" for failed runs; Error count otherwise
-
BENCHMARKING_DIR = "./benchmarking" # Base directory for benchmarking inputs / outputs
DATASETS_DIR = (
@@ -26,9 +14,7 @@ class BenchmarkingResult(TypedDict):
)
DATASETS_REFACTORED_SAVE_DIR = f"{DATASETS_DIR}/old-runs/refactored" # Directory for datasets that will be modified
-OUTPUT_DIR = f"{BENCHMARKING_DIR}/outputs" # Directory for storing outputs
-OUTPUT_LOGS_DIR = f"{OUTPUT_DIR}/logs" # Directory for storing outputs
-RESULTS_DIR = f"{OUTPUT_DIR}/results" # Directory for storing result csvs
+OUTPUT_DIR = f"{BENCHMARKING_DIR}/results" # Directory for storing outputs
SRC_DIR = f"{BENCHMARKING_DIR}/sources" # Directory for storing text files listing source files for each project
COMPILED_CLASSES_DIR = (
f"{BENCHMARKING_DIR}/compiled_classes" # Directory for storing compiled_classes
@@ -41,8 +27,8 @@ class BenchmarkingResult(TypedDict):
ANNOTATOR_JAR_DIR = f"{JARS_DIR}/annotator"
PROCESSOR_JARS = [
{
- "PATH": f"{ERRORPRONE_JAR_DIR}/error_prone_core-2.35.1-with-dependencies.jar",
- "DOWNLOAD_URL": "https://repo1.maven.org/maven2/com/google/errorprone/error_prone_core/2.35.1/error_prone_core-2.35.1.jar",
+ "PATH": f"{ERRORPRONE_JAR_DIR}/error_prone_core-2.38.0-with-dependencies.jar",
+ "DOWNLOAD_URL": "https://repo1.maven.org/maven2/com/google/errorprone/error_prone_core/2.38.0/error_prone_core-2.38.0.jar",
},
{
"PATH": f"{ERRORPRONE_JAR_DIR}/dataflow-errorprone-3.49.3-eisop1.jar",
@@ -92,17 +78,13 @@ class BenchmarkingResult(TypedDict):
"-J--add-opens=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED",
]
-DEBUG = False
-
-benchmark_start_time_string = f"{datetime.now():%Y-%m-%d_%H:%M:%S}"
-
# The initialization stage for benchmarking
# Creates the necessary directories, saves old refactored datasets, confirms the existence of the necessary jar files, and downloads NJR-1 dataset if it has not been already.
def stage_zero():
print("Beginning Stage Zero: Initialization...")
- save_dir = f"{DATASETS_REFACTORED_SAVE_DIR}/{benchmark_start_time_string}"
+ save_dir = f"{DATASETS_REFACTORED_SAVE_DIR}/{datetime.now():%Y-%m-%d_%H:%M:%S}"
print(f"Saving existing refactored datasets to {save_dir}")
if os.path.exists(DATASETS_REFACTORED_DIR):
try:
@@ -116,8 +98,6 @@ def stage_zero():
print("Initializing benchmarking folders and datasets")
os.makedirs(SRC_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
- os.makedirs(OUTPUT_LOGS_DIR, exist_ok=True)
- os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(DATASETS_DIR, exist_ok=True)
os.makedirs(DATASETS_CACHE_DIR, exist_ok=True)
os.makedirs(COMPILED_CLASSES_DIR, exist_ok=True)
@@ -148,7 +128,7 @@ def stage_zero():
sys.exit(1)
print("Creating copy of NJR-1 datasets cache to refactor...")
- res = os.system(f"cp -a {DATASETS_CACHE_DIR} {DATASETS_REFACTORED_DIR}")
+ res = os.system(f"cp -av {DATASETS_CACHE_DIR} {DATASETS_REFACTORED_DIR}")
if res != 0:
print(f"Copy dataset cache failed with exit code {res}. Exiting Program")
sys.exit(1)
@@ -171,319 +151,11 @@ def stage_zero():
print("Benchmarking Stage Zero Completed\n")
-def stage_one():
- """
- Runs the full benchmarking routine (Annotate -> Count Errors -> Refactor -> Annotate -> Count Errors) for every dataset in the NJR-1 dataset collection and then summarizes the results.
- """
- datasets_list = os.listdir(DATASETS_REFACTORED_DIR)
-
- # List of data structures representing the results of a benchmark
- results: list[BenchmarkingResult] = []
-
- for dataset in datasets_list:
- print(f"Benchmarking {dataset}...")
- os.makedirs(f"{OUTPUT_DIR}/{dataset}", exist_ok=True)
-
- ## Step 1: Annotate dataset
- stage_one_annotate(dataset)
-
- ## Step 2: Count initial errors
- old_err_count = stage_one_count_errors(dataset)
- if old_err_count is None:
- print(f"Skipping {dataset} due to javac/NullAway crash.")
- results.append(
- {
- "benchmark": dataset,
- "initial_error_count": "Error",
- "refactored_error_count": "",
- }
- )
- continue
-
- ## Step 3: Refactor dataset
- stage_one_refactor(dataset)
-
- ## Step 4: Count errors after refactoring
- new_err_count = stage_one_count_errors(dataset)
- if new_err_count is None:
- print(f"Skipping {dataset} due to javac/NullAway crash after refactoring.")
- results.append(
- {
- "benchmark": dataset,
- "initial_error_count": old_err_count,
- "refactored_error_count": "Error",
- }
- )
- continue
-
- print(
- f"Succesfully benchmarked {dataset}. Errors: {old_err_count} --> {new_err_count}\n"
- )
- results.append(
- {
- "benchmark": dataset,
- "initial_error_count": old_err_count,
- "refactored_error_count": new_err_count,
- }
- )
- print(f"Finished benchmarking datasets.")
- print(f"Saving results to csv...")
-
- stage_one_save_results(results)
- return
-
-
-# Utility Functions
-def stage_one_annotate(dataset: str):
- """
- Runs NullAwayAnnotator on the passed dataset in order to prepare it for
- refactoring
- """
-
- print(f"Annotating {dataset}...")
-
- # Create config files
- os.makedirs(ANNOTATOR_OUT_DIR, exist_ok=True)
- with open(f"{ANNOTATOR_CONFIG}", "w+") as config_file:
- _ = config_file.write(f"{NULLAWAY_CONFIG}/\t{SCANNER_CONFIG}\n")
-
- # Clear annotator output folder (required for annotator to run)
- shutil.rmtree(ANNOTATOR_OUT_DIR + "/0", ignore_errors=True)
-
- build_cmd = " ".join(get_build_cmd(dataset))
- cwd = os.getcwd()
-
- annotate_cmd: list[str] = [
- "java",
- "-jar",
- ANNOTATOR_JAR,
- # Absolute path of an Empty Directory where all outputs of AnnotatorScanner and NullAway are serialized.
- "-d",
- ANNOTATOR_OUT_DIR,
- # Command to run Nullaway on target; Should be executable from anywhere
- "-bc",
- f'"cd {cwd} && {build_cmd}"',
- # Path to a TSV file containing value of config paths
- "-cp",
- ANNOTATOR_CONFIG,
- # Fully qualified name of the @Initializer annotation.
- "-i",
- "com.uber.nullaway.annotations.Initializer",
- # Checker name to be used for the analysis.
- "-cn",
- "NULLAWAY",
- # Max depth to traverse as part of the analysis search
- "--depth",
- "10",
- ]
- res = subprocess.run(annotate_cmd, text=True, capture_output=True)
- if res.returncode != 0:
- print(
- f"Annotation failed with exit code {res.returncode} for dataset {dataset}"
- )
- return
-
- output_log_path = f"{OUTPUT_DIR}/{dataset}/annotator.txt"
- with open(output_log_path, "w+") as f:
- f.write(f"CMD:\n\t{" ".join(annotate_cmd)}\n")
- f.write(f"STDOUT:\n\t{res.stdout}\n")
- f.write(f"STDERR:\n\t{res.stderr}\n")
-
- if DEBUG:
- print(
- f"Command used to annotate dataset {dataset}: \n\t{" ".join(annotate_cmd)}\n"
- )
-
- return
-
-
-def stage_one_refactor(dataset: str):
- """
- Runs VGR on the passed dataset
- """
-
- print(f"Refactoring {dataset}...")
-
- output_file = f"{OUTPUT_DIR}/{dataset}/refactoring.txt"
- dataset_path = f"{DATASETS_REFACTORED_DIR}/{dataset}"
-
- refactor_cmd: list[str] = ["./gradlew", "run", f"'--args={dataset_path} All'"]
-
- with open(output_file, "w+") as f:
- res = subprocess.run(
- " ".join(refactor_cmd) + f" &> {output_file}", shell=True, check=False
- )
-
- if res.returncode != 0:
- print(
- f"Running VGRTool failed with exit code {res.returncode} for dataset {dataset}. See {output_file} for more details."
- )
-
- if DEBUG:
- print(
- f"Refactor Command for dataset {dataset}: : {' '.join(refactor_cmd)} &> {output_file}"
- )
- return
-
-
-def stage_one_count_errors(dataset: str):
- """Builds the passed datsets and counts NullAway errors during the build process."""
- build_cmd = " ".join(get_build_cmd(dataset))
- log_file = f"{OUTPUT_LOGS_DIR}/{dataset}-error_count_log-{datetime.now():%Y-%m-%d_%H:%M:%S}.txt"
- output_file = (
- f"{OUTPUT_LOGS_DIR}/{dataset}-error_count-{benchmark_start_time_string}.txt"
- )
-
- # Build the dataset and redirect all outputs to a log file
- with open(log_file, "w+") as f:
- res = subprocess.run(
- build_cmd,
- stdout=f,
- stderr=subprocess.STDOUT,
- check=False,
- text=True,
- shell=True,
- )
- f.write(build_cmd)
-
- # Handle javac / NullAway crash
- # if res.returncode != 0:
- # print(
- # f"Building dataset {dataset} failed with exit code {res.returncode}. Skipping dataset..."
- # )
- # return None # Return None type so programs which are erroring do not look like real results
-
- # Read the log file and count occurrences of NullAway errors
- with open(log_file, "r") as f:
- error_count = len(re.findall(r"error: \[NullAway\]", f.read()))
-
- with open(output_file, "a") as f:
- f.write(f"Error Count: {error_count}\n")
-
- if DEBUG:
- print(f"Number of errors found for dataset {dataset}: {error_count}")
- return error_count
-
-
-def get_build_cmd(dataset: str):
- """
- Constructs the full 'javac' build command used to compile the passed dataset.
- """
- lib_dir = f"{DATASETS_REFACTORED_DIR}/{dataset}/lib"
- src_file = get_source_files(dataset)
- plugin_options = get_plugin_options(dataset)
-
- build_cmd: list[str] = ["javac"]
- build_cmd += ERROR_PRONE_EXPORTS
- build_cmd += [
- "-d",
- f"{COMPILED_CLASSES_DIR}",
- "-cp",
- f"{lib_dir}:{ANNOTATOR_JAR}",
- "-XDcompilePolicy=simple",
- "--should-stop=ifError=FLOW",
- "-processorpath",
- f"{PROCESSOR_JAR_PATHS}",
- f"'{plugin_options}'",
- "-Xmaxerrs",
- "0",
- "-Xmaxwarns",
- "0",
- f"@{src_file}",
- ]
- return build_cmd
-
-
-def get_source_files(dataset):
- find_srcs_command = [
- "find",
- f"{DATASETS_REFACTORED_DIR}/{dataset}/src",
- "-name",
- "*.java",
- ]
- src_file = f"{SRC_DIR}/{dataset}.txt"
- with open(src_file, "w+") as f:
- _ = subprocess.run(find_srcs_command, stdout=f)
- return src_file
-
-
-def get_plugin_options(dataset: str):
- """
- Generates the -Xplugin:ErrorProne option string, including a dynamically generated list of packages to annotate.
- """
- dataset_path = f"{DATASETS_REFACTORED_DIR}/{dataset}"
- find_pkgs_command = (
- f"find {dataset_path}"
- + " -name '*.java' -exec awk 'FNR==1 && /^package/ {print $2}' {} + | sed 's/;//' | sort -u | tr '\n\r' ',' | sed 's/,,/,/g' | sed 's/,$//'"
- )
-
- pkgs = subprocess.run(
- find_pkgs_command, shell=True, capture_output=True
- ).stdout.decode("utf-8")
-
- # Split the annotated packages
- annotated_pkgs = pkgs.strip()
- annotated_pkgs_arg = f"-XepOpt:NullAway:AnnotatedPackages={annotated_pkgs}"
-
- return f"-Xplugin:ErrorProne \
- -XepDisableAllChecks \
- -Xep:AnnotatorScanner:ERROR \
- -XepOpt:AnnotatorScanner:ConfigPath={SCANNER_CONFIG} \
- -Xep:NullAway:ERROR \
- -XepOpt:NullAway:SerializeFixMetadata=true \
- -XepOpt:NullAway:FixSerializationConfigPath={NULLAWAY_CONFIG} \
- {annotated_pkgs_arg}"
-
-
-def stage_one_save_results(results):
- """
- Saves benchmark results to a csv"
- """
-
- csv_path = f"{RESULTS_DIR}/results-{benchmark_start_time_string}"
-
- column_names = [
- "benchmark",
- "initial_error_count",
- "refactored_error_count",
- ]
-
- # Write CSV
- with open(csv_path, "w+") as f:
- writer = csv.DictWriter(f, fieldnames=column_names)
- writer.writeheader()
-
- for result in results:
- row = {
- "benchmark": result["benchmark"],
- "initial_error_count": (result["initial_error_count"]),
- "refactored_error_count": (result["refactored_error_count"]),
- }
- writer.writerow(row)
-
- print(f"Saved results to {csv_path}")
-
-
def run():
"""
Runs the full benchmarking routine for every dataset in the NJR-1 dataset collection and then summarizes the results.
"""
stage_zero()
- stage_one()
-
-
-def main():
- """Main entry point of the script."""
- global DEBUG
- argparser = argparse.ArgumentParser(description="Runs benchmark.")
- argparser.add_argument(
- "--debug", action="store_true", help="Enabling debugging statements."
- )
- args = argparser.parse_args()
- DEBUG = args.debug
-
- run()
-if __name__ == "__main__":
- main()
+run()
diff --git a/src/main/java/AddNullCheckBeforeDereferenceRefactoring.java b/src/main/java/AddNullCheckBeforeDereferenceRefactoring.java
index 2d081a3..364fd52 100644
--- a/src/main/java/AddNullCheckBeforeDereferenceRefactoring.java
+++ b/src/main/java/AddNullCheckBeforeDereferenceRefactoring.java
@@ -1,200 +1,202 @@
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTNode;
-import org.eclipse.jdt.core.dom.ArrayAccess;
-import org.eclipse.jdt.core.dom.Block;
+import org.eclipse.jdt.core.dom.Assignment;
import org.eclipse.jdt.core.dom.ConditionalExpression;
import org.eclipse.jdt.core.dom.Expression;
-import org.eclipse.jdt.core.dom.FieldAccess;
+import org.eclipse.jdt.core.dom.IBinding;
import org.eclipse.jdt.core.dom.IfStatement;
import org.eclipse.jdt.core.dom.InfixExpression;
-import org.eclipse.jdt.core.dom.MethodInvocation;
import org.eclipse.jdt.core.dom.NullLiteral;
import org.eclipse.jdt.core.dom.ParenthesizedExpression;
-import org.eclipse.jdt.core.dom.QualifiedName;
+import org.eclipse.jdt.core.dom.PrefixExpression;
import org.eclipse.jdt.core.dom.SimpleName;
import org.eclipse.jdt.core.dom.VariableDeclarationFragment;
import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
-
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
- * This class represents a refactoring in which explicit null checks are added
- * before a value dereference
+ * A refactoring module that replaces checks on variables whose nullness is
+ * dependent on the nullness of another with another variable, by checking the
+ * original (independent) variable directly.
+ *
+ * Example:
+ *
+ *
{@code
+ * // Before:
+ * Class> dependentVar = (independentVar != null ? independentVar.getDependent() : null);
+ * if (dependentVar != null) {
+ * // ...
+ * }
+ *
+ * // After:
+ * Class> dependentVar = (independentVar != null ? independentVar.getDependent() : null);
+ * if (independentVar != null) {
+ * // ...
+ * }
+ * }
+ *
*/
public class AddNullCheckBeforeDereferenceRefactoring extends Refactoring {
public static final String NAME = "AddNullCheckBeforeDereferenceRefactoring";
+ private static final Logger LOGGER = LogManager.getLogger();
/**
- * Optional list of expressions identified as possibly null (to guide
- * applicability)
+ * List of dependent variables and the independent variable they rely on * Uses
+ * each variable's ({@link org.eclipse.jdt.core.dom.IVariableBinding}) as the
+ * key, ensuring global uniqueness. Two variables who have the same name but
+ * have different scopes will have different IBinding instances.
*/
- @SuppressWarnings("unused")
- private List possiblyNullExpressions;
-
- private static final Logger LOGGER = LogManager.getLogger();
+ private final Map validRefactors;
/** Default constructor (for RefactoringEngine integration) */
public AddNullCheckBeforeDereferenceRefactoring() {
- super(); // Call to base class (if it expects a name/ID)
- }
-
- /** Constructor that accepts a list of possibly-null expressions */
- public AddNullCheckBeforeDereferenceRefactoring(List possiblyNullExpressions) {
super();
- this.possiblyNullExpressions = possiblyNullExpressions;
+ validRefactors = new HashMap<>();
}
@Override
public boolean isApplicable(ASTNode node) {
- if (node instanceof MethodInvocation || node instanceof FieldAccess || node instanceof QualifiedName
- || node instanceof ArrayAccess) {
- return true;
+ if (node instanceof VariableDeclarationFragment varFrag) {
+ return isApplicable(varFrag);
}
if (node instanceof IfStatement ifStmt) {
- Expression condition = ifStmt.getExpression();
- if (condition instanceof InfixExpression infix
- && infix.getOperator() == InfixExpression.Operator.NOT_EQUALS) {
- Expression leftOperand = infix.getLeftOperand();
- Expression rightOperand = infix.getRightOperand();
-
- if ((leftOperand instanceof SimpleName && rightOperand instanceof NullLiteral)
- || (rightOperand instanceof SimpleName && leftOperand instanceof NullLiteral)) {
- LOGGER.debug("Found indirect null check in if-statement: {}", condition);
- return true;
- }
- }
+ return isApplicable(ifStmt);
}
+
+ if (node instanceof Assignment assignment) {
+ verifyRefactors(assignment);
+ }
+
+ LOGGER.debug("Node " + node.getClass().getSimpleName() + " is NOT applicable. Skipping.");
return false;
}
- @Override
- public void apply(ASTNode node, ASTRewrite rewriter) {
- LOGGER.debug("Processing ASTNode: {}", node.getClass().getSimpleName());
-
- AST ast = node.getAST();
-
- if (node instanceof MethodInvocation exprNode) {
- LOGGER.debug("Target Expression: " + (exprNode).getExpression());
- } else if (node instanceof FieldAccess exprNode) {
- LOGGER.debug("Target Expression: " + (exprNode).getExpression());
- } else if (node instanceof QualifiedName exprNode) {
- LOGGER.debug("Target Expression: " + (exprNode).getQualifier());
- } else if (node instanceof ArrayAccess exprNode) {
- LOGGER.debug("Target Expression: " + (exprNode).getArray());
- } else {
- LOGGER.debug("Node is not a dereferenceable expression.");
- return;
+ private boolean isApplicable(VariableDeclarationFragment var) {
+ Expression initializer = var.getInitializer();
+ if (initializer == null)
+ return false;
+ List varInitializerFragments = getSubExpressions(initializer);
+ AST ast = var.getAST();
+ for (Expression varInitFrag : varInitializerFragments) {
+
+ Expression condition;
+ if (varInitFrag instanceof ConditionalExpression ternary) {
+ if (ternary.getThenExpression() instanceof NullLiteral) {
+ // depObj != null when condition is false
+ ParenthesizedExpression tempParen = ast.newParenthesizedExpression();
+ tempParen.setExpression((Expression) ASTNode.copySubtree(ast, ternary.getExpression()));;
+
+ PrefixExpression tempPrefix = ast.newPrefixExpression();
+ tempPrefix.setOperator(PrefixExpression.Operator.NOT);
+ tempPrefix.setOperand(tempParen);
+ condition = tempPrefix;
+ } else if (ternary.getElseExpression() instanceof NullLiteral) {
+ // depObj != null when condition is true
+ condition = ternary.getExpression();
+ } else {
+ // Ternary must contain NullLiteral
+ continue;
+ }
+ LOGGER.debug("Found Ternary Assignment: %s", var.getName());
+ LOGGER.debug("Found Ternary Condition: %s", condition);
+ validRefactors.put(var.resolveBinding(), condition);
+ }
}
+ return false;
+ }
- ASTNode parentNode = node.getParent();
- VariableDeclarationFragment assignedVariable = null;
- IfStatement existingIfStatement = null;
+ private boolean isApplicable(IfStatement ifStmt) {
+ Expression ifStmtCondition = ifStmt.getExpression();
+ LOGGER.debug("Analyzing if-statement: %s", ifStmtCondition);
+ List conditionFragments = Refactoring.getSubExpressions(ifStmtCondition);
+ for (Expression condition : conditionFragments) {
+ if (!(condition instanceof InfixExpression infix)) {
+ continue;
+ }
- // 1️⃣ Find the variable assigned via a ternary operator
- while (parentNode != null) {
- if (parentNode instanceof VariableDeclarationFragment varDecl) {
- Expression initializer = varDecl.getInitializer();
+ if (infix.getOperator() != InfixExpression.Operator.NOT_EQUALS) {
+ continue;
+ }
- while (initializer instanceof ParenthesizedExpression) {
- initializer = ((ParenthesizedExpression) initializer).getExpression();
- }
+ Expression leftOperand = infix.getLeftOperand();
+ Expression rightOperand = infix.getRightOperand();
- if (initializer instanceof ConditionalExpression ternary) {
- if (ternary.getElseExpression() instanceof NullLiteral) {
- assignedVariable = varDecl;
- LOGGER.debug("Found ternary assignment: " + assignedVariable.getName());
- LOGGER.debug("Ternary condition: " + ternary.getExpression());
- }
- }
- break;
+ SimpleName varName;
+ if (rightOperand instanceof SimpleName rightVarName && leftOperand instanceof NullLiteral) {
+ varName = rightVarName;
+ } else if (leftOperand instanceof SimpleName leftVarName && rightOperand instanceof NullLiteral) {
+ varName = leftVarName;
+ } else {
+ continue;
+ }
+ if (validRefactors.get(varName.resolveBinding()) != null) {
+ LOGGER.debug("Found indirect null check in if-statement: " + condition);
+ return true;
}
- parentNode = parentNode.getParent();
}
+ LOGGER.debug("No valid refactors found for IfStatement %s", ifStmt);
+ return false;
+ }
- if (assignedVariable == null) {
- LOGGER.debug("No ternary assignment found.");
+ @Override
+ public void apply(ASTNode node, ASTRewrite rewriter) {
+ if (!(node instanceof IfStatement ifStmt)) {
return;
}
- // 2️⃣ Find the if-statement checking the assigned variable
- ASTNode current = assignedVariable.getParent();
- while (current != null) {
- if (current instanceof IfStatement ifStmt) {
- Expression condition = ifStmt.getExpression();
-
- if (condition instanceof InfixExpression infix) {
- if (infix.getOperator() == InfixExpression.Operator.NOT_EQUALS
- && infix.getLeftOperand() instanceof SimpleName) {
-
- SimpleName varName = (SimpleName) infix.getLeftOperand();
- if (varName.getIdentifier().equals(assignedVariable.getName().getIdentifier())) {
- existingIfStatement = ifStmt;
- LOGGER.debug("Found indirect null check in if-statement: {}", condition);
- break;
- }
- }
- }
- }
+ Expression ifStmtCondition = ifStmt.getExpression();
+ LOGGER.debug("Analyzing if-statement: " + ifStmtCondition);
+ List conditionFragments = Refactoring.getSubExpressions(ifStmtCondition);
- // ✅ Instead of going up the AST, we move **forward** in the block
- if (current.getParent() instanceof Block block) {
- List> statements = block.statements();
- int index = statements.indexOf(current);
-
- // Move forward to find an if-statement
- for (int i = index + 1; i < statements.size(); i++) {
- ASTNode nextNode = (ASTNode) statements.get(i);
- if (nextNode instanceof IfStatement) {
- current = nextNode;
- break;
- }
- }
- } else {
- current = current.getParent();
+ for (Expression condition : conditionFragments) {
+ // Skip non-equality check conditionals
+ if (!(condition instanceof InfixExpression infix)) {
+ continue;
}
- }
- if (existingIfStatement != null && assignedVariable != null) {
- // Retrieve initializer and ensure it's not wrapped in a ParenthesizedExpression
- Expression initializer = assignedVariable.getInitializer();
+ Expression leftOperand = infix.getLeftOperand();
+ Expression rightOperand = infix.getRightOperand();
- // ✅ Unwrap ParenthesizedExpression before proceeding
- while (initializer instanceof ParenthesizedExpression) {
- initializer = ((ParenthesizedExpression) initializer).getExpression();
+ SimpleName varName;
+ if (rightOperand instanceof SimpleName rightVarName && leftOperand instanceof NullLiteral) {
+ varName = rightVarName;
+ } else if (leftOperand instanceof SimpleName leftVarName && rightOperand instanceof NullLiteral) {
+ varName = leftVarName;
+ } else {
+ continue;
}
- // ✅ Now, safely cast to ConditionalExpression
- if (initializer instanceof ConditionalExpression ternary) {
- Expression directCheckExpr = (Expression) ASTNode.copySubtree(ast, ternary.getExpression());
+ Expression ternary = validRefactors.get(varName.resolveBinding());
- LOGGER.debug("Replacing condition: {}", existingIfStatement.getExpression());
- LOGGER.debug("New condition: {}", directCheckExpr);
+ AST ast = node.getAST();
+ ParenthesizedExpression pExpression = ast.newParenthesizedExpression();
+ pExpression.setExpression((Expression) ASTNode.copySubtree(ast, ternary));
- rewriter.replace(existingIfStatement.getExpression(), directCheckExpr, null);
- } else {
- if (initializer != null)
- LOGGER.error("Expected ConditionalExpression but found: {}",
- initializer.getClass().getSimpleName());
- }
+ LOGGER.debug("[DEBUG] Replacing Variable: " + varName);
+ LOGGER.debug("[DEBUG] New Value: " + pExpression);
+
+ rewriter.replace(condition, pExpression, null);
}
+
}
- /**
- * Helper function: Checks if an if-condition indirectly checks a variable
- * assigned via a ternary
+ /*
+ * Checks Assignment node to see if it re-assigns an existing valid refactoring,
+ * and if so removes it from validRefactors
*/
- @SuppressWarnings("unused")
- private boolean isIndirectNullCheck(Expression condition, SimpleName assignedVariable) {
- if (condition instanceof InfixExpression infixExpr) {
- return infixExpr.getOperator() == InfixExpression.Operator.NOT_EQUALS
- && infixExpr.getLeftOperand() instanceof SimpleName && ((SimpleName) infixExpr.getLeftOperand())
- .getIdentifier().equals(assignedVariable.getIdentifier());
+ private void verifyRefactors(Assignment assignmentNode) {
+ Expression lhs = assignmentNode.getLeftHandSide();
+ if (!(lhs instanceof SimpleName varName)) {
+ return;
+ }
+ if (validRefactors.get(varName.resolveBinding()) != null) {
+ validRefactors.remove(varName.resolveBinding());
}
- return false;
}
-
}
diff --git a/src/main/java/BooleanFlagRefactoring.java b/src/main/java/BooleanFlagRefactoring.java
new file mode 100644
index 0000000..891d89c
--- /dev/null
+++ b/src/main/java/BooleanFlagRefactoring.java
@@ -0,0 +1,175 @@
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.eclipse.jdt.core.dom.AST;
+import org.eclipse.jdt.core.dom.ASTNode;
+import org.eclipse.jdt.core.dom.Assignment;
+import org.eclipse.jdt.core.dom.ConditionalExpression;
+import org.eclipse.jdt.core.dom.Expression;
+import org.eclipse.jdt.core.dom.IBinding;
+import org.eclipse.jdt.core.dom.IfStatement;
+import org.eclipse.jdt.core.dom.InfixExpression;
+import org.eclipse.jdt.core.dom.NullLiteral;
+import org.eclipse.jdt.core.dom.ParenthesizedExpression;
+import org.eclipse.jdt.core.dom.PrimitiveType;
+import org.eclipse.jdt.core.dom.SimpleName;
+import org.eclipse.jdt.core.dom.VariableDeclarationFragment;
+import org.eclipse.jdt.core.dom.VariableDeclarationStatement;
+import org.eclipse.jdt.core.dom.InfixExpression.Operator;
+import org.eclipse.jdt.core.dom.rewrite.ASTRewrite;
+
+/**
+ * This class represents a refactoring in which boolean flags are replaced with
+ * explicit null checks
+ */
+public class BooleanFlagRefactoring extends Refactoring {
+ public static final String NAME = "BooleanFlagRefactoring";
+
+ /**
+ * List of variable names identified as boolean flags, along with their
+ * corresponding initializer expression
+ */
+ private final Map flagExpressions;
+
+ /** Default constructor (for RefactoringEngine integration) */
+ public BooleanFlagRefactoring() {
+ super();
+ this.flagExpressions = new HashMap<>();
+ }
+
+ @Override
+ public boolean isApplicable(ASTNode node) {
+ if (node instanceof VariableDeclarationStatement stmt) {
+ return isApplicable(stmt);
+ } else if (node instanceof IfStatement ifStmt) {
+ return isApplicable(ifStmt);
+ } else if (node instanceof Assignment assignment) {
+ checkReassignment(assignment);
+ }
+ return false;
+ }
+
+ /**
+ * Checks to see if a VariableDeclarationStatement defines a boolean flag that
+ * represents another variable's nullness
+ */
+ private boolean isApplicable(VariableDeclarationStatement stmt) {
+ boolean isBooleanDeclaration = (stmt.getType() instanceof PrimitiveType pType
+ && pType.getPrimitiveTypeCode() == PrimitiveType.BOOLEAN);
+
+ if (!isBooleanDeclaration) {
+ return false;
+ }
+
+ boolean flagFound = false;
+ AST ast = stmt.getAST();
+
+ // Search through all declared variables in declaration node for a booleanflag
+ for (VariableDeclarationFragment frag : (List) stmt.fragments()) {
+ Expression varInitializer = frag.getInitializer();
+ if (varInitializer == null) {
+ continue;
+ }
+
+ for (Expression expression : Refactoring.getSubExpressions(varInitializer)) {
+ if (expression instanceof ConditionalExpression cExpr) {
+ expression = cExpr.getExpression();
+ }
+ if (expression instanceof InfixExpression infix && isEqualityOperator(infix.getOperator())
+ && getNullComparisonVariable(infix) != null) {
+ ParenthesizedExpression copiedExpression = ast.newParenthesizedExpression();
+ copiedExpression.setExpression((Expression) ASTNode.copySubtree(ast, varInitializer));
+ flagExpressions.put(frag.getName().resolveBinding(), copiedExpression);
+ flagFound = true;
+ }
+ }
+ }
+ return flagFound;
+ }
+
+ /**
+ * Analyzes an IfStatement to see if it contains a check utilizing an identified
+ * boolean flag
+ */
+ private boolean isApplicable(IfStatement ifStmt) {
+ List exprFragments = Refactoring.getSubExpressions(ifStmt.getExpression());
+ for (Expression expr : exprFragments) {
+ if (expr instanceof InfixExpression infix && isEqualityOperator(infix.getOperator())) {
+ Expression leftOperand = infix.getLeftOperand();
+ Expression rightOperand = infix.getRightOperand();
+
+ if ((leftOperand instanceof SimpleName lhs && isFlag(lhs))
+ || (rightOperand instanceof SimpleName rhs && isFlag(rhs))) {
+ return true;
+ }
+ }
+ if (expr instanceof SimpleName varName && isFlag(varName)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private boolean isFlag(SimpleName potentialFlag) {
+ return flagExpressions.get(potentialFlag.resolveBinding()) != null;
+ }
+
+ private boolean isEqualityOperator(Operator op) {
+ return (op == Operator.NOT_EQUALS || op == Operator.EQUALS);
+ }
+
+ private SimpleName getNullComparisonVariable(InfixExpression infix) {
+ Expression leftOperand = infix.getLeftOperand();
+ Expression rightOperand = infix.getRightOperand();
+ if (leftOperand instanceof SimpleName varName && rightOperand instanceof NullLiteral) {
+ return varName;
+ } else if (rightOperand instanceof SimpleName varName && leftOperand instanceof NullLiteral) {
+ return varName;
+ }
+ return null;
+
+ }
+
+ @Override
+ public void apply(ASTNode node, ASTRewrite rewriter) {
+ if (!(node instanceof IfStatement ifStmt)) {
+ return;
+ }
+ List exprFragments = Refactoring.getSubExpressions(ifStmt.getExpression());
+ for (Expression expression : exprFragments) {
+ if (expression instanceof InfixExpression infix && isEqualityOperator(infix.getOperator())) {
+ SimpleName flagName = getNullComparisonVariable(infix);
+ apply(rewriter, flagName);
+ }
+ if (expression instanceof SimpleName flagName) {
+ apply(rewriter, flagName);
+ }
+ }
+ }
+
+ private void apply(ASTRewrite rewriter, SimpleName flagName) {
+ if (flagName == null || !isFlag(flagName)) {
+ return;
+ }
+ Expression newExpr = flagExpressions.get(flagName.resolveBinding());
+ if (newExpr != null) {
+ rewriter.replace(flagName, newExpr, null);
+ }
+
+ }
+
+ /*
+ * Checks Assignment node to see if it re-assigns an existing boolean flag, and
+ * if so removes the flag from flagExpressions
+ */
+ private void checkReassignment(Assignment assignmentNode) {
+ Expression lhs = assignmentNode.getLeftHandSide();
+ if (!(lhs instanceof SimpleName varName)) {
+ return;
+ }
+ if (isFlag(varName)) {
+ flagExpressions.remove(varName.resolveBinding());
+ }
+ }
+}
diff --git a/src/main/java/Refactoring.java b/src/main/java/Refactoring.java
index b97f898..71f8c85 100644
--- a/src/main/java/Refactoring.java
+++ b/src/main/java/Refactoring.java
@@ -35,6 +35,7 @@ public abstract class Refactoring {
* Recursively analyzes an expression and returns the boolean comparison
* subexpressions that comprise it.
*
+ *
* @param expr
* The expression to analyze
* @return A list of all subexpressions within an expreesion
diff --git a/src/main/java/RefactoringEngine.java b/src/main/java/RefactoringEngine.java
index 4cb44d4..edc4352 100644
--- a/src/main/java/RefactoringEngine.java
+++ b/src/main/java/RefactoringEngine.java
@@ -1,4 +1,5 @@
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import org.apache.logging.log4j.Logger;
@@ -30,8 +31,11 @@ public RefactoringEngine(List refactoringNames) {
switch (name) {
case AddNullCheckBeforeDereferenceRefactoring.NAME ->
refactorings.add(new AddNullCheckBeforeDereferenceRefactoring());
+ case BooleanFlagRefactoring.NAME -> refactorings.add(new BooleanFlagRefactoring());
case SentinelRefactoring.NAME -> refactorings.add(new SentinelRefactoring());
case NestedNullRefactoring.NAME -> refactorings.add(new NestedNullRefactoring());
+ case "All" -> refactorings.addAll(Arrays.asList(new AddNullCheckBeforeDereferenceRefactoring(),
+ new BooleanFlagRefactoring(), new SentinelRefactoring(), new NestedNullRefactoring()));
default -> System.err.println("Unknown refactoring: " + name);
}
diff --git a/src/main/resources/log4j2.xml b/src/main/resources/log4j2.xml
new file mode 100644
index 0000000..be144b8
--- /dev/null
+++ b/src/main/resources/log4j2.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/test/java/BooleanFlagTesting.java b/src/test/java/BooleanFlagTesting.java
new file mode 100644
index 0000000..1682cf8
--- /dev/null
+++ b/src/test/java/BooleanFlagTesting.java
@@ -0,0 +1,372 @@
+import org.junit.jupiter.api.Test;
+
+/**
+ * Class to perform JUnit tests on the BooleanFlagRefactoring refactoring module
+ *
+ * @see BooleanFlagRefactoring
+ */
+public class BooleanFlagTesting {
+
+ public void test(String input, String expectedOutput) {
+ TestingEngine.testSingleRefactoring(input, expectedOutput, "BooleanFlagRefactoring");
+ }
+
+ @Test
+ public void simpleTest() {
+ String input = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNull = x == null;
+ if (xIsNull) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNull = x == null;
+ if ((x == null)) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void swappedSignsTest() {
+ String input = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNotNull = x != null;
+ if (xIsNotNull) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNotNull = x != null;
+ if ((x != null)) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void inverseFlagTest1() {
+ String input = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNotNull = x != null;
+ if (!xIsNotNull) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNotNull = x != null;
+ if (!(x != null)) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void inverseFlagTest2() {
+ String input = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNotNull = x != null;
+ if (!xIsNotNull) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNotNull = x != null;
+ if (!(x != null)) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void andConditionTest() {
+ String input = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNull = x == null;
+ if (xIsNull && 1 > 0) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNull = x == null;
+ if ((x == null) && 1 > 0) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void ternaryTest() {
+ String input = """
+ class TernaryBooleanFlagTest {
+ @SuppressWarnings("all")
+ void test() {
+ boolean xIsNull = (handlerMethod == null ? true : false);
+ Object exceptionHandlerObject = null;
+ Method exceptionHandlerMethod = null;
+
+ if (xIsNull) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ class TernaryBooleanFlagTest {
+ @SuppressWarnings("all")
+ void test() {
+ boolean xIsNull = (handlerMethod == null ? true : false);
+ Object exceptionHandlerObject = null;
+ Method exceptionHandlerMethod = null;
+
+ if (((handlerMethod == null ? true : false))) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void inverseTernaryTest() {
+ String input = """
+ class TernaryBooleanFlagTest {
+ @SuppressWarnings("all")
+ void test() {
+ boolean xIsNotNull = (handlerMethod != null ? true : false);
+ Object exceptionHandlerObject = null;
+ Method exceptionHandlerMethod = null;
+
+ if (!xIsNotNull) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ class TernaryBooleanFlagTest {
+ @SuppressWarnings("all")
+ void test() {
+ boolean xIsNotNull = (handlerMethod != null ? true : false);
+ Object exceptionHandlerObject = null;
+ Method exceptionHandlerMethod = null;
+
+ if (!((handlerMethod != null ? true : false))) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void newContainerTest() {
+ String input = """
+ public class NewContainerTest {
+ List items = Arrays.asList("Hello World");
+
+ public void test() {
+ boolean hasItems = (items != null && !items.isEmpty());
+
+ // Indirectly implies items != null
+ if (hasItems) {
+ TreeSet> set = new TreeSet<>(items);
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class NewContainerTest {
+ List items = Arrays.asList("Hello World");
+
+ public void test() {
+ boolean hasItems = (items != null && !items.isEmpty());
+
+ // Indirectly implies items != null
+ if (((items != null && !items.isEmpty()))) {
+ TreeSet> set = new TreeSet<>(items);
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void reassignmentTest() {
+ String input = """
+ public class NewContainerTest {
+ List items = Arrays.asList("Hello World");
+
+ public void test() {
+ boolean hasItems = (items != null && !items.isEmpty());
+ hasItems = true;
+
+ // Due to reassignment implies nothing
+ if (hasItems) {
+ TreeSet> set = new TreeSet<>(items);
+ }
+ }
+ }
+ """;
+ String expectedOutput = input; // No changes should be made
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void nextedExpressionTest() {
+ String input = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNull = (!(!(x == null)));
+ if (xIsNull) {
+ ;
+ }
+
+ boolean xIsNull2 = ((5 > 3) && (2 < 3 && (x == null)));
+ if (xIsNull2) {
+ ;
+ }
+
+ boolean xIsNull3 = (!((5 > 3) && !(2 < 3 && !(x == null))));
+ if (xIsNull3) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class BooleanFlagTest {
+ public void testMethod(String x) {
+ boolean xIsNull = (!(!(x == null)));
+ if (((!(!(x == null))))) {
+ ;
+ }
+
+ boolean xIsNull2 = ((5 > 3) && (2 < 3 && (x == null)));
+ if ((((5 > 3) && (2 < 3 && (x == null))))) {
+ ;
+ }
+
+ boolean xIsNull3 = (!((5 > 3) && !(2 < 3 && !(x == null))));
+ if (((!((5 > 3) && !(2 < 3 && !(x == null)))))) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void multipleFlagsTest() {
+ String input = """
+ public class Test {
+ public void test(String a, String b) {
+ boolean aIsNull = a == null, bIsNull = b == null;
+ if (aIsNull || bIsNull) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ public void test(String a, String b) {
+ boolean aIsNull = a == null, bIsNull = b == null;
+ if ((a == null) || (b == null)) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void complexMultipleFlagsTest() {
+ String input = """
+ public class Test {
+ public void test(String a, String b) {
+ boolean aIsNull = a == null, bIsNull = b == null;
+ if ((aIsNull && bIsNull) || (!aIsNull && b != null)) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ public void test(String a, String b) {
+ boolean aIsNull = a == null, bIsNull = b == null;
+ if (((a == null) && (b == null)) || (!(a == null) && b != null)) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void shadowingTest() {
+ String input = """
+ public class Test {
+ boolean xIsNull;
+
+ public void test(String x) {
+ xIsNull = x == null;
+ boolean xIsNull = true;
+ if (xIsNull) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = input;
+ test(input, expectedOutput);
+ }
+}
diff --git a/src/test/java/DereferenceTesting.java b/src/test/java/DereferenceTesting.java
new file mode 100644
index 0000000..1757a0d
--- /dev/null
+++ b/src/test/java/DereferenceTesting.java
@@ -0,0 +1,225 @@
+
+import org.junit.jupiter.api.Test;
+
+/**
+ * Class to perform JUnit tests on the AddNullCheckBeforeDereferenceRefactoring
+ * refactoring module
+ */
+public class DereferenceTesting {
+
+ public void test(String input, String expectedOutput) {
+ TestingEngine.testSingleRefactoring(input, expectedOutput, AddNullCheckBeforeDereferenceRefactoring.NAME);
+ }
+
+ @Test
+ public void simpleTest() {
+ String input = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? independentObj.getDependent() : null);
+ if (dependentObj != null) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? independentObj.getDependent() : null);
+ if ((independentObj != null)) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void swappedSignsTest() {
+ String input = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj == null ? factory.getDependent() : null);
+ if (dependentObj != null) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj == null ? factory.getDependent() : null);
+ if ((independentObj == null)) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void prefixTest() {
+ String input = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? independentObj.getDependent() : null);
+ if (!(dependentObj != null)) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? independentObj.getDependent() : null);
+ if (!((independentObj != null))) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void infixTest() {
+ String input = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = ((independentObj != null && 5 > 3) ? independentObj.getDependent() : null);
+ if (dependentObj != null) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = ((independentObj != null && 5 > 3) ? independentObj.getDependent() : null);
+ if (((independentObj != null && 5 > 3))) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void swappedSidesTest() {
+ String input = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? null : factory.getDependent());
+ if (dependentObj != null) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? null : factory.getDependent());
+ if ((!(independentObj != null))) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void reassignmentTest() {
+ String input = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? independentObj.getDependent() : null);
+ dependentObj = null;
+ if (dependentObj != null) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? independentObj.getDependent() : null);
+ dependentObj = null;
+ if (dependentObj != null) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void reassignmentTest2() {
+ String input = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? independentObj.getDependent() : null);
+ dependentObj = someMethod();
+ if (dependentObj != null) {
+ ;
+ }
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ private void test() {
+ Class> dependentObj = (independentObj != null ? independentObj.getDependent() : null);
+ dependentObj = someMethod();
+ if (dependentObj != null) {
+ ;
+ }
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+
+ @Test
+ public void shadowingTest() {
+ String input = """
+ public class Test {
+ int val = 0;
+ public void test() {
+ int val = 0;
+ String str = "Hello World";
+ if (str == null) {
+ val = -1;
+ }
+ if (this.val == 0) {
+ System.out.println("Str is not null");
+ }
+ }
+ """;
+ String expectedOutput = """
+ public class Test {
+ int val = 0;
+ public void test() {
+ int val = 0;
+ String str = "Hello World";
+ if (str == null) {
+ val = -1;
+ }
+ if (this.val == 0) {
+ System.out.println("Str is not null");
+ }
+ }
+ """;
+ test(input, expectedOutput);
+ }
+}
diff --git a/src/test/java/TestingEngine.java b/src/test/java/TestingEngine.java
index a7a2729..6036649 100644
--- a/src/test/java/TestingEngine.java
+++ b/src/test/java/TestingEngine.java
@@ -13,8 +13,9 @@ public class TestingEngine {
/**
* RefactoringEngine to use to run tests
*/
- private static RefactoringEngine fullEngine = new RefactoringEngine(Lists.newArrayList(
- AddNullCheckBeforeDereferenceRefactoring.NAME, NestedNullRefactoring.NAME, SentinelRefactoring.NAME));
+ private static RefactoringEngine fullEngine = new RefactoringEngine(
+ Lists.newArrayList(AddNullCheckBeforeDereferenceRefactoring.NAME, BooleanFlagRefactoring.NAME,
+ NestedNullRefactoring.NAME, SentinelRefactoring.NAME));
// TODO: WRITE VARIANTS FOR SUPPORTED JAVA VERSIONS
private static ASTParser parser = ASTParser.newParser(AST.getJLSLatest()); // Use appropriate