Skip to content

Commit

Permalink
Merge pull request #1270 from stan-dev/cleanup-tests-fatal-errors
Browse files Browse the repository at this point in the history
Add missing tests and unify fatal error messages
  • Loading branch information
WardBrian authored Nov 7, 2022
2 parents e83fe4b + 6985fa0 commit dcc107a
Show file tree
Hide file tree
Showing 35 changed files with 313 additions and 135 deletions.
11 changes: 6 additions & 5 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ let transform_program (mir : Program.Typed.t)
; log_prob= log_prob'
; generate_quantities= generate_quantities' }
| _ ->
raise (Failure "Something went wrong with program transformation packing!")
Common.FatalError.fatal_error_msg
[%message "Something went wrong with program transformation packing!"]

(**
Apply the transformation to each function body and to each program block separately.
Expand All @@ -53,8 +54,8 @@ let transform_program_blockwise (mir : Program.Typed.t)
match transform fd {pattern= SList s; meta= Location_span.empty} with
| {pattern= SList l; _} -> l
| _ ->
raise
(Failure "Something went wrong with program transformation packing!")
Common.FatalError.fatal_error_msg
[%message "Something went wrong with program transformation packing!"]
in
let transformed_functions =
List.map mir.functions_block ~f:(fun fs ->
Expand Down Expand Up @@ -1202,8 +1203,8 @@ let optimize_soa (mir : Program.Typed.t) =
match transform {pattern= SList s; meta= Location_span.empty} with
| {pattern= SList (l : Stmt.Located.t list); _} -> l
| _ ->
raise
(Failure "Something went wrong with program transformation packing!")
Common.FatalError.fatal_error_msg
[%message "Something went wrong with program transformation packing!"]
in
{mir with log_prob= transform' mir.log_prob}

Expand Down
8 changes: 4 additions & 4 deletions src/analysis_and_optimization/Pedantic_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,10 @@ let list_param_dependant_fundef_cf (mir : Program.Typed.t)
Set.Poly.map dep_args ~f:(fun (loc, ix, arg_name) ->
(loc, List.nth_exn arg_exprs ix, arg_name) )
| _ ->
raise
(Failure
"In finding searching for parameter dependent functionarguments, \
mismatched function. Please report a bug.\n" ) in
Common.FatalError.fatal_error_msg
[%message
"In finding searching for parameter dependent function arguments, \
mismatched function."] in
let arg_param_deps label arg_expr =
var_deps info_map ~expr:(Some arg_expr) label (parameter_names_set mir)
in
Expand Down
2 changes: 1 addition & 1 deletion src/analysis_and_optimization/Pedantic_dist_warnings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ let constr_mismatch_warning (constr : var_constraint_named) (arg : arg_info)
let arg_fail_msg =
Printf.sprintf "Distribution %s at %s expects more arguments." name
(Location_span.to_string loc) in
raise (Failure arg_fail_msg) in
Common.FatalError.fatal_error_msg [%message arg_fail_msg] in
match v with
| Param (pname, trans), meta ->
if transform_mismatch_constraint constr.constr trans then
Expand Down
7 changes: 0 additions & 7 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@ let is_deprecated_distribution name =
let rename_deprecated map name =
Map.find map name |> Option.map ~f:fst |> Option.value ~default:name

let distribution_suffix name =
let open String in
is_suffix ~suffix:"_lpdf" name
|| is_suffix ~suffix:"_lpmf" name
|| is_suffix ~suffix:"_lcdf" name
|| is_suffix ~suffix:"_lccdf" name

let userdef_distributions stmts =
let open String in
List.filter_map
Expand Down
1 change: 0 additions & 1 deletion src/frontend/Deprecation_analysis.mli
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ val update_suffix : string -> Middle.UnsizedType.t -> string
val collect_userdef_distributions :
typed_program -> Middle.UnsizedType.t String.Map.t

val distribution_suffix : string -> bool
val without_suffix : string list -> string -> string
val is_deprecated_distribution : string -> bool
val deprecated_distributions : (string * string) String.Map.t
Expand Down
3 changes: 0 additions & 3 deletions src/frontend/Environment.mli
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ type originblock =
| TParam
| Model
| GQuant
[@@deriving sexp]

(** Information available for each variable *)
type varinfo = {origin: originblock; global: bool; readonly: bool}
[@@deriving sexp]

type info =
{ type_: UnsizedType.t
Expand All @@ -25,7 +23,6 @@ type info =
| `UserDeclared of Location_span.t
| `StanMath
| `UserDefined ] }
[@@deriving sexp]

type t

Expand Down
23 changes: 0 additions & 23 deletions src/frontend/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ module TypeError = struct
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
* SignatureMismatch.function_mismatch
| IllTypedReduceSumGeneric of
string
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
* SignatureMismatch.function_mismatch
| IllTypedVariadic of
string
* UnsizedType.t list
Expand Down Expand Up @@ -128,9 +123,6 @@ module TypeError = struct
| IllTypedReduceSum (name, arg_tys, expected_args, error) ->
SignatureMismatch.pp_signature_mismatch ppf
(name, arg_tys, ([((ReturnType UReal, expected_args), error)], false))
| IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) ->
SignatureMismatch.pp_signature_mismatch ppf
(name, arg_tys, ([((ReturnType UReal, expected_args), error)], false))
| IllTypedVariadic (name, arg_tys, args, error, return_type) ->
SignatureMismatch.pp_signature_mismatch ppf
( name
Expand Down Expand Up @@ -286,7 +278,6 @@ end

module ExpressionError = struct
type t =
| InvalidMapRectFn of string
| InvalidSizeDeclRng
| InvalidRngFunction
| InvalidUnnormalizedFunction
Expand All @@ -298,11 +289,6 @@ module ExpressionError = struct
| IntTooLarge

let pp ppf = function
| InvalidMapRectFn fn_name ->
Fmt.pf ppf
"Mapped function cannot be an _rng or _lp function, found function \
name: %s"
fn_name
| InvalidSizeDeclRng ->
Fmt.pf ppf
"Random number generators are not allowed in top level size \
Expand Down Expand Up @@ -545,12 +531,6 @@ let returning_fn_expected_nonreturning_found loc name =
let illtyped_reduce_sum loc name arg_tys args error =
TypeError (loc, TypeError.IllTypedReduceSum (name, arg_tys, args, error))

let illtyped_reduce_sum_generic loc name arg_tys expected_args error =
TypeError
( loc
, TypeError.IllTypedReduceSumGeneric (name, arg_tys, expected_args, error)
)

let illtyped_variadic loc name arg_tys args fn_rt error =
TypeError (loc, TypeError.IllTypedVariadic (name, arg_tys, args, error, fn_rt))

Expand Down Expand Up @@ -615,9 +595,6 @@ let ident_not_in_scope loc name sug =
let ident_has_unnormalized_suffix loc name =
IdentifierError (loc, IdentifierError.UnnormalizedSuffix name)

let invalid_map_rect_fn loc name =
ExpressionError (loc, ExpressionError.InvalidMapRectFn name)

let invalid_decl_rng_fn loc =
ExpressionError (loc, ExpressionError.InvalidSizeDeclRng)

Expand Down
9 changes: 0 additions & 9 deletions src/frontend/Semantic_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,6 @@ val illtyped_reduce_sum :
-> SignatureMismatch.function_mismatch
-> t

val illtyped_reduce_sum_generic :
Location_span.t
-> string
-> UnsizedType.t list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> SignatureMismatch.function_mismatch
-> t

val ambiguous_function_promotion :
Location_span.t
-> string
Expand Down Expand Up @@ -99,7 +91,6 @@ val ident_is_model_name : Location_span.t -> string -> t
val ident_is_stanmath_name : Location_span.t -> string -> t
val ident_in_use : Location_span.t -> string -> t
val ident_not_in_scope : Location_span.t -> string -> string option -> t
val invalid_map_rect_fn : Location_span.t -> string -> t
val invalid_decl_rng_fn : Location_span.t -> t
val invalid_rng_fn : Location_span.t -> t
val invalid_unnormalized_fn : Location_span.t -> t
Expand Down
1 change: 0 additions & 1 deletion src/frontend/SignatureMismatch.mli
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ and details = private
and function_mismatch = private
| ArgError of int * type_mismatch
| ArgNumMismatch of int * int
[@@deriving sexp]

type signature_error =
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
Expand Down
6 changes: 4 additions & 2 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ let get_consistent_types ad_level type_ es =

let check_array_expr loc es =
match es with
| [] -> Semantic_error.empty_array loc |> error
| [] ->
(* NB: This is actually disallowed by parser *)
Semantic_error.empty_array loc |> error
| {emeta= {ad_level; type_; _}; _} :: _ -> (
match get_consistent_types ad_level type_ es with
| Error (ty, meta) ->
Expand Down Expand Up @@ -631,7 +633,7 @@ and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
| _ ->
let expected_args, err =
basic_mismatch () |> Result.error |> Option.value_exn in
Semantic_error.illtyped_reduce_sum_generic loc id.name
Semantic_error.illtyped_reduce_sum loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ unsized_dims:
no_assign:
| UNREACHABLE
{ (* This code will never be reached *)
raise (Failure "This should be unreachable; the UNREACHABLE token should \
never be produced")
Common.FatalError.fatal_error_msg
[%message "the UNREACHABLE token should never be produced"]
}

optional_assignment(rhs):
Expand Down
8 changes: 4 additions & 4 deletions src/middle/Internal_fun.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ type 'expr t =
let to_string
?(expr_to_string =
fun _ ->
raise
(Failure
"Should not be parsing expression from string in function renaming"
)) x =
Common.FatalError.fatal_error_msg
[%message
"Should not be parsing expression from string in function renaming"])
x =
Sexp.to_string (sexp_of_t expr_to_string x) ^ "__"
let pp (pp_expr : 'a Fmt.t) ppf internal =
Expand Down
42 changes: 0 additions & 42 deletions src/middle/SizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,6 @@ let rec pp pp_e ppf = function
Fmt.(pair ~sep:comma (fun ppf st -> pp pp_e ppf st) pp_e |> brackets)
(st, expr)

let collect_exprs st =
let rec aux accu = function
| SInt | SReal | SComplex -> List.rev accu
| SVector (_, e)
|SRowVector (_, e)
|SComplexVector e
|SComplexRowVector e ->
List.rev @@ (e :: accu)
| SMatrix (_, e1, e2) | SComplexMatrix (e1, e2) ->
List.rev @@ (e1 :: e2 :: accu)
| SArray (inner, e) -> aux (e :: accu) inner in
aux [] st

let rec to_unsized = function
| SInt -> UnsizedType.UInt
| SReal -> UReal
Expand All @@ -63,25 +50,12 @@ let rec to_unsized = function
| SComplexMatrix _ -> UComplexMatrix
| SArray (t, _) -> UArray (to_unsized t)

let rec inner_type st = match st with SArray (t, _) -> inner_type t | t -> t

let rec contains_complex st =
match st with
| SComplex | SComplexVector _ | SComplexRowVector _ | SComplexMatrix _ -> true
| SArray (t, _) -> contains_complex t
| _ -> false

let rec dims_of st =
match st with
| SArray (t, _) -> dims_of t
| SMatrix (_, d1, d2) | SComplexMatrix (d1, d2) -> [d1; d2]
| SRowVector (_, dim)
|SVector (_, dim)
|SComplexRowVector dim
|SComplexVector dim ->
[dim]
| SInt | SReal | SComplex -> []

(**
Get the dimensions with respect to sizes needed for IO.
{b Note}: The main difference from get_dims is complex,
Expand Down Expand Up @@ -150,17 +124,6 @@ let%expect_test "dims" =
|> print_endline ;
[%expect {| z, x, y |}]

(**
* Return true if SizedType contains an Eigen type
*)
let rec contains_eigen_type st =
match st with
| SInt | SReal | SComplex -> false
| SVector _ | SRowVector _ | SMatrix _ | SComplexVector _
|SComplexRowVector _ | SComplexMatrix _ ->
true
| SArray (t, _) -> contains_eigen_type t

let is_complex_type st = UnsizedType.is_complex_type (to_unsized st)

(**
Expand All @@ -174,11 +137,6 @@ let rec get_mem_pattern st =
| SVector (mem, _) | SRowVector (mem, _) | SMatrix (mem, _, _) -> mem
| SArray (t, _) -> get_mem_pattern t

(**
* Return true if SizedType contains a type tagged SoA
*)
let contains_soa st = Mem_pattern.compare (get_mem_pattern st) SoA = 0

(*Given a sizedtype, demote it's mem pattern from SoA to AoS*)
let rec demote_sizedtype_mem st =
match st with
Expand Down
2 changes: 0 additions & 2 deletions src/middle/Type.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ let pp pp_e ppf = function
| Sized st -> SizedType.pp pp_e ppf st
| Unsized ust -> UnsizedType.pp ppf ust

let collect_exprs = function Sized st -> SizedType.collect_exprs st | _ -> []

let to_unsized = function
| Sized st -> SizedType.to_unsized st
| Unsized ut -> ut
7 changes: 0 additions & 7 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,6 @@ let rec common_type = function
| _, _ -> None

(* -- Helpers -- *)
let rec is_real_type = function
| UReal | UVector | URowVector | UMatrix -> true
| UArray x -> is_real_type x
| _ -> false

let rec is_autodiffable = function
| UReal | UVector | URowVector | UMatrix -> true
Expand Down Expand Up @@ -189,9 +185,6 @@ let is_array ut =
false
| UArray _ -> true

let return_contains_eigen_type ret =
match ret with ReturnType t -> contains_eigen_type t | Void -> false

let rec is_indexing_matrix = function
| UArray t, _ :: idcs -> is_indexing_matrix (t, idcs)
| (UMatrix | UComplexMatrix), [] -> false
Expand Down
4 changes: 3 additions & 1 deletion src/stan_math_backend/Expression_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ let rec pp_possibly_var_decl ppf (adtype, ut, mem_pattern) =
|UComplexMatrix ->
pf ppf "%a" pp_var_decl ut
| UReal | UInt | UComplex -> pf ppf "%a" pp_unsizedtype_local (adtype, ut)
| x -> raise_s [%message (x : UnsizedType.t) "not implemented yet"]
| x ->
Common.FatalError.fatal_error_msg
[%message (x : UnsizedType.t) "not implemented yet"]

let suffix_args = function
| Fun_kind.FnRng -> ["base_rng__"]
Expand Down
36 changes: 36 additions & 0 deletions test/integration/bad/algebra_solver/bad_data_qualifer.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
functions {
vector algebra_system (data vector y,
vector theta,
array[] real x_r,
array[] int x_i) {
vector[2] f_y;
f_y[1] = y[1] - theta[1];
f_y[2] = y[2] - theta[2];
return f_y;
}
}


data {

}

transformed data {
vector[2] y;
array[0] real x_r;
array[0] real x_i;
}

parameters {
vector[2] theta_p;
real dummy_parameter;
}

transformed parameters {
vector[2] y_s_p;
y_s_p = solve_newton(algebra_system, y, theta_p, x_r, x_i, 0.01, 0.01, 10);
}

model {
dummy_parameter ~ normal(0, 1);
}
Loading

0 comments on commit dcc107a

Please sign in to comment.