Skip to content

Commit 16bb25b

Browse files
authored
Merge pull request #664 from stan-dev/bugfix/elementwise-pow
Bugfix/elementwise pow and empty matrices
2 parents 3032660 + f55adde commit 16bb25b

10 files changed

Lines changed: 4634 additions & 904 deletions

File tree

src/middle/Stan_math_signatures.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,11 +1079,18 @@ let () =
10791079
add_unqualified ("is_nan", ReturnType UInt, [UReal]) ;
10801080
add_binary "lbeta" ;
10811081
add_binary "lchoose" ;
1082-
add_unqualified
1083-
("linspaced_array", ReturnType (UArray UReal), [UInt; UReal; UReal]) ;
1084-
add_unqualified
1085-
("linspaced_row_vector", ReturnType URowVector, [UInt; UReal; UReal]) ;
1086-
add_unqualified ("linspaced_vector", ReturnType UVector, [UInt; UReal; UReal]) ;
1082+
add_qualified
1083+
( "linspaced_array"
1084+
, ReturnType (UArray UReal)
1085+
, [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] ) ;
1086+
add_qualified
1087+
( "linspaced_row_vector"
1088+
, ReturnType URowVector
1089+
, [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] ) ;
1090+
add_qualified
1091+
( "linspaced_vector"
1092+
, ReturnType UVector
1093+
, [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] ) ;
10871094
add_unqualified ("lkj_corr_cholesky_log", ReturnType UReal, [UMatrix; UReal]) ;
10881095
add_unqualified ("lkj_corr_cholesky_lpdf", ReturnType UReal, [UMatrix; UReal]) ;
10891096
add_unqualified ("lkj_corr_cholesky_rng", ReturnType UMatrix, [UInt; UReal]) ;

src/stan_math_backend/Expression_gen.ml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,7 @@ and gen_operator_app = function
219219
fun ppf es ->
220220
pp_scalar_binary ppf "(%a@ /@ %a)" "elt_divide(@,%a,@ %a)" es
221221
| Pow -> fun ppf es -> pp_binary_f ppf "pow" es
222-
| EltPow ->
223-
fun ppf es -> pp_scalar_binary ppf "(%a@ *@ %a)" "pow(@,%a,@ %a)" es
222+
| EltPow -> fun ppf es -> pp_binary_f ppf "pow" es
224223
| Equals -> fun ppf es -> pp_binary_f ppf "logical_eq" es
225224
| NEquals -> fun ppf es -> pp_binary_f ppf "logical_neq" es
226225
| Less -> fun ppf es -> pp_binary_f ppf "logical_lt" es
@@ -358,11 +357,16 @@ and pp_user_defined_fun ppf (f, es) =
358357
and pp_compiler_internal_fn ut f ppf es =
359358
let pp_array_literal ppf es =
360359
let pp_add_method ppf () = pf ppf ")@,.add(" in
361-
pf ppf "stan::math::array_builder<%a>()@,.add(%a)@,.array()"
362-
pp_unsizedtype_local
363-
(promote_adtype es, promote_unsizedtype es)
364-
(list ~sep:pp_add_method pp_expr)
365-
es
360+
if List.length es = 0 then
361+
pf ppf "stan::math::array_builder<%a>()@,.add(0)@,.array()"
362+
pp_unsizedtype_local
363+
(promote_adtype es, promote_unsizedtype es)
364+
else
365+
pf ppf "stan::math::array_builder<%a>()@,.add(%a)@,.array()"
366+
pp_unsizedtype_local
367+
(promote_adtype es, promote_unsizedtype es)
368+
(list ~sep:pp_add_method pp_expr)
369+
es
366370
in
367371
match Internal_fun.of_string_opt f with
368372
| Some FnMakeArray -> pp_array_literal ppf es

src/stan_math_backend/Stan_math_code_gen.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,8 @@ using stan::model::index_min_max;
724724
using stan::model::index_multi;
725725
using stan::model::index_omni;
726726
using stan::model::nil_index_list;
727-
using namespace stan::math; |}
727+
using namespace stan::math;
728+
using stan::math::pow; |}
728729

729730
(** Functions needed in the model class not defined yet in stan math.
730731
FIXME: Move these to the Stan repo when these repos are joined.

test/integration/good/code-gen/cl.expected

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ using stan::model::index_min_max;
5454
using stan::model::index_multi;
5555
using stan::model::index_omni;
5656
using stan::model::nil_index_list;
57-
using namespace stan::math;
57+
using namespace stan::math;
58+
using stan::math::pow;
5859

5960
static int current_statement__ = 0;
6061
static const std::vector<string> locations_array__ = {" (found before start of program)",

test/integration/good/code-gen/cpp.expected

Lines changed: 1931 additions & 639 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)