Skip to content

[nullsafety] Non-nullable check warnings #12231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src-json/warning.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@
"parent": "WTyper",
"enabled": false
},
{
"name": "WRedundantNullCheck",
"doc": "Value can't be null, so comparison with null is excessive",
"parent": "WTyper"
},
{
"name": "WHxb",
"doc": "Hxb (either --hxb output or haxe compiler cache) related warnings"
Expand Down
2 changes: 1 addition & 1 deletion src/context/common.ml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class compiler_callbacks = object(self)
method add_after_generation (f : unit -> unit) : unit =
after_generation := f :: !after_generation

method add_null_safety_report (f : (string*pos) list -> unit) : unit =
method add_null_safety_report (f : (WarningList.warning option*string*pos) list -> unit) : unit =
null_safety_report <- f :: null_safety_report

method run handle_error r =
Expand Down
7 changes: 4 additions & 3 deletions src/macro/macroApi.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2375,9 +2375,10 @@ let macro_api ccom get_api =
);
"on_null_safety_report", vfun1 (fun f ->
let f = prepare_callback f 1 in
(ccom()).callbacks#add_null_safety_report (fun (errors:(string*pos) list) ->
let encode_item (msg,pos) =
encode_obj [("msg", encode_string msg); ("pos", encode_pos pos)]
(ccom()).callbacks#add_null_safety_report (fun (errors:(WarningList.warning option*string*pos) list) ->
let encode_item (wtype,msg,pos) =
let wtype = match wtype with | Some _ -> "warning" | None -> "error" in
encode_obj [("type", encode_string wtype); ("msg", encode_string msg); ("pos", encode_pos pos)]
in
ignore(f [encode_array (List.map encode_item errors)])
);
Expand Down
74 changes: 61 additions & 13 deletions src/typing/nullSafety.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,23 @@ open Type
type safety_message = {
sm_msg : string;
sm_pos : pos;
sm_type : WarningList.warning option
}

type safety_report = {
mutable sr_errors : safety_message list;
mutable sr_warnings: safety_message list;
}

let add_error report msg pos =
let error = { sm_msg = ("Null safety: " ^ msg); sm_pos = pos; } in
let error = { sm_type = None; sm_msg = ("Null safety: " ^ msg); sm_pos = pos; } in
if not (List.mem error report.sr_errors) then
report.sr_errors <- error :: report.sr_errors;
report.sr_errors <- error :: report.sr_errors;;

let add_warning report wtype msg pos =
let warning = { sm_type = Some wtype; sm_msg = ("Null safety: " ^ msg); sm_pos = pos; } in
if not (List.mem warning report.sr_warnings) then
report.sr_warnings <- warning :: report.sr_warnings;

type scope_type =
| STNormal
Expand Down Expand Up @@ -447,7 +454,7 @@ let rec contains_safe_meta metadata =
let safety_enabled meta =
(contains_safe_meta meta) && not (contains_unsafe_meta meta)

let safety_mode (metadata:Ast.metadata) =
let get_safety_mode (metadata:Ast.metadata) =
let rec traverse mode meta =
match mode, meta with
| Some SMOff, _
Expand Down Expand Up @@ -1053,7 +1060,6 @@ class expr_checker mode immediate_execution report =
val mutable in_closure = false
(* if this flag is `true` then spotted errors and warnings will not be reported *)
val mutable is_pretending = false
(* val mutable cnt = 0 *)
(**
Get safety mode for this expression checker
*)
Expand All @@ -1072,6 +1078,33 @@ class expr_checker mode immediate_execution report =
in
add_error report msg (get_first_valid_pos positions)
end
(**
Register a warning
*)
method warning wtype msg (positions:Globals.pos list) =
if not is_pretending then begin
let rec get_first_valid_pos positions =
match positions with
| [] -> null_pos
| p :: rest ->
if p <> null_pos then p
else get_first_valid_pos rest
in
add_warning report wtype msg (get_first_valid_pos positions)
end

method private check_binop_redundant_null_checks e =
match e.eexpr with
| TBinop ((OpEq | OpNotEq), { eexpr = TConst TNull }, expr)
| TBinop ((OpEq | OpNotEq), expr, { eexpr = TConst TNull })
| TBinop(OpAssignOp OpNullCoal, expr, _)
| TBinop (OpNullCoal, expr, _) ->
if not (is_nullable_type ~dynamic_is_nullable:true expr.etype) then
self#warning
WRedundantNullCheck
("The operand type is not nullable, so null-check should be redundant.")
[expr.epos; e.epos];
| _ -> ()
(**
Check if `e` is nullable even if the type is reported not-nullable.
Haxe type system lies sometimes.
Expand Down Expand Up @@ -1180,7 +1213,9 @@ class expr_checker mode immediate_execution report =
| TConst _ -> ()
| TLocal _ -> ()
| TArray (arr, idx) -> self#check_array_access arr idx e.epos
| TBinop (op, left_expr, right_expr) -> self#check_binop op left_expr right_expr e.epos
| TBinop (op, left_expr, right_expr) ->
self#check_binop_redundant_null_checks e;
self#check_binop op left_expr right_expr e.epos
| TField (target, access) -> self#check_field target access e.epos
| TTypeExpr _ -> ()
| TParenthesis e -> self#check_expr e
Expand Down Expand Up @@ -1539,7 +1574,7 @@ class class_checker cls immediate_execution report (main_expr : texpr option) =
object (self)
val is_safe_class = (safety_enabled cls_meta)
val mutable checker = new expr_checker SMLoose immediate_execution report
val mutable mode = None
val mutable mode : safety_mode option = None
(**
Entry point for checking a class
*)
Expand All @@ -1549,7 +1584,7 @@ class class_checker cls immediate_execution report (main_expr : texpr option) =
self#check_var_fields;
let check_field is_static f = if not (has_class_field_flag f CfPostProcessed) then begin
validate_safety_meta report f.cf_meta;
match (safety_mode (cls_meta @ f.cf_meta)) with
match (get_safety_mode (cls_meta @ f.cf_meta)) with
| SMOff -> ()
| mode ->
(match f.cf_expr with
Expand All @@ -1560,7 +1595,7 @@ class class_checker cls immediate_execution report (main_expr : texpr option) =
self#check_accessors is_static f
end in
if is_safe_class then
Option.may ((self#get_checker (safety_mode cls_meta))#check_root_expr) (TClass.get_cl_init cls);
Option.may ((self#get_checker (get_safety_mode cls_meta))#check_root_expr) (TClass.get_cl_init cls);
Option.may (check_field false) cls.cl_constructor;
List.iter (check_field false) cls.cl_ordered_fields;
List.iter (check_field true) cls.cl_ordered_statics;
Expand Down Expand Up @@ -1601,7 +1636,7 @@ class class_checker cls immediate_execution report (main_expr : texpr option) =
match mode with
| Some mode -> mode
| None ->
let m = safety_mode cls_meta in
let m = get_safety_mode cls_meta in
mode <- Some m;
m
(**
Expand Down Expand Up @@ -1784,7 +1819,10 @@ class class_checker cls immediate_execution report (main_expr : texpr option) =
*)
let run (com:Common.context) (types:module_type list) =
let report = Timer.time com.timer_ctx ["null safety"] (fun () ->
let report = { sr_errors = [] } in
let report = {
sr_errors = [];
sr_warnings = [];
} in
let immediate_execution = new immediate_execution in
let traverse module_type =
match module_type with
Expand All @@ -1798,11 +1836,21 @@ let run (com:Common.context) (types:module_type list) =
) () in
match com.callbacks#get_null_safety_report with
| [] ->
List.iter (fun err -> Common.display_error com err.sm_msg err.sm_pos) (List.rev report.sr_errors)
List.iter (fun warn ->
com.warning (Option.get warn.sm_type) [] warn.sm_msg warn.sm_pos
) (List.rev report.sr_warnings);

List.iter (fun err ->
Common.display_error com err.sm_msg err.sm_pos
) (List.rev report.sr_errors)
| callbacks ->
let warnings =
List.map (fun warn -> (warn.sm_type, warn.sm_msg, warn.sm_pos)) report.sr_warnings
in
let errors =
List.map (fun err -> (err.sm_msg, err.sm_pos)) report.sr_errors
List.map (fun err -> (err.sm_type, err.sm_msg, err.sm_pos)) report.sr_errors
in
List.iter (fun fn -> fn errors) callbacks
let all = warnings @ errors in
List.iter (fun fn -> fn all) callbacks

;;
26 changes: 20 additions & 6 deletions tests/nullsafety/src/Validator.hx
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@
import haxe.macro.Context;
import haxe.macro.Expr;

typedef SafetyMessage = {msg:String, pos:Position}
typedef ExpectedMessage = {symbol:String, pos:Position}
typedef SafetyMessage = {type:String, msg:String, pos:Position}
typedef ExpectedMessage = {type:String, symbol:String, pos:Position}
#end

class Validator {
#if macro
static var expectedErrors:Array<ExpectedMessage> = [];
static var expectedWarnings:Array<ExpectedMessage> = [];

static dynamic function onNullSafetyReport(callback:(errors:Array<SafetyMessage>)->Void):Void {
}

static public function register() {
expectedErrors = [];
expectedWarnings = [];
onNullSafetyReport = @:privateAccess Context.load("on_null_safety_report", 1);
onNullSafetyReport(validate);
}
Expand All @@ -25,7 +27,13 @@ class Validator {
if(meta.name == ':shouldFail') {
var fieldPosInfos = Context.getPosInfos(field.pos);
fieldPosInfos.min = Context.getPosInfos(meta.pos).max + 1;
expectedErrors.push({symbol: field.name, pos:Context.makePosition(fieldPosInfos)});
expectedErrors.push({type: "error", symbol: field.name, pos:Context.makePosition(fieldPosInfos)});
break;
}
if(meta.name == ':shouldWarn') {
var fieldPosInfos = Context.getPosInfos(field.pos);
fieldPosInfos.min = Context.getPosInfos(meta.pos).max + 1;
expectedWarnings.push({type: "warning", symbol: field.name, pos:Context.makePosition(fieldPosInfos)});
break;
}
}
Expand All @@ -34,7 +42,7 @@ class Validator {
}

static function validate(errors:Array<SafetyMessage>) {
var errors = check(expectedErrors.copy(), errors.copy());
var errors = check(expectedErrors.concat(expectedWarnings), errors.copy());
if(errors.ok) {
Sys.println('${errors.passed} expected errors spotted');
Sys.println('Compile-time tests passed.');
Expand All @@ -50,6 +58,7 @@ class Validator {
var actualEvent = actual[i];
var wasExpected = false;
for(expectedEvent in expected) {
if (expectedEvent.type != actualEvent.type) continue;
if(posContains(expectedEvent.pos, actualEvent.pos)) {
expected.remove(expectedEvent);
wasExpected = true;
Expand Down Expand Up @@ -85,7 +94,12 @@ class Validator {
#end

macro static public function shouldFail(expr:Expr):Expr {
expectedErrors.push({symbol:Context.getLocalMethod(), pos:expr.pos});
expectedErrors.push({type: "error", symbol:Context.getLocalMethod(), pos:expr.pos});
return expr;
}

macro static public function shouldWarn(expr:Expr):Expr {
expectedWarnings.push({type: "warning", symbol:Context.getLocalMethod(), pos:expr.pos});
return expr;
}
}
}
3 changes: 2 additions & 1 deletion tests/nullsafety/src/cases/TestLoose.hx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cases;

import Validator.shouldFail;
import Validator.shouldWarn;

typedef NotNullAnon = {
a:String
Expand Down Expand Up @@ -133,7 +134,7 @@ class TestLoose {
}

static function nullCoal_returnNull_shouldPass(token:{children:Array<Int>}):Null<Bool> {
final children = token.children ?? return null;
final children = shouldWarn(token.children ?? return null);
var i = children.length;
return null;
}
Expand Down
60 changes: 60 additions & 0 deletions tests/nullsafety/src/cases/TestNonNullable.hx
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package cases;

import Validator.shouldWarn;
import Validator.shouldFail;

typedef Data = {
var foo:String;
}

class TestNonNullable {
static function main() {
final foo = 0;
if (shouldWarn(foo) == null) {}

final dyn:Dynamic = null;
if (dyn == null) {}

final dyn:Any = 1;
if (shouldWarn(dyn) == null) {}

var data:Data = haxe.Json.parse("{}");
data.foo.length;

switch shouldWarn(data.foo) {
case null if (shouldWarn(data.foo) == null):
final v = shouldWarn(data.foo) == null;
}

final v = shouldWarn(data.foo) == null;
shouldWarn(data.foo) != null && true;
true && shouldWarn(data.foo) != null;
data.foo != null || true;
true || shouldWarn(data.foo) != null;

throw shouldWarn(data.foo) == null;

function foo():Bool {
return shouldWarn(data.foo) == null;
}

while (shouldWarn(data.foo) == null) {}

shouldWarn(data.foo) ??= "";
final foo = shouldWarn(data.foo ?? "");
if (null == shouldWarn(data.foo)) {
trace(1);
}
if (shouldWarn(data.foo) == null) {
data.foo = "default";
}
}
}

@:build(Validator.checkFields())
class BasicErrors {
@:shouldFail static var foo2:Int;
public function new() {
shouldFail(var foo:Int = null);
}
}
4 changes: 3 additions & 1 deletion tests/nullsafety/test.hxml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ cases.TestStrictThreaded
cases.TestLoose
cases.TestSafeFieldInUnsafeClass
cases.TestAbstract
cases.TestNonNullable

--macro nullSafety('cases.TestLoose', Loose)
--macro nullSafety('cases.TestStrict', Strict)
--macro nullSafety('cases.TestStrictThreaded', StrictThreaded)
--macro Validator.register()
--macro nullSafety('cases.TestNonNullable', Loose)
--macro Validator.register()
Loading