@@ -5971,20 +5971,24 @@ struct test_tri : public test_case {
59715971 }
59725972};
59735973
5974- // GGML_OP_CONST
5975- struct test_const : public test_case {
5974+ // GGML_OP_FILL
5975+ struct test_fill : public test_case {
59765976 const ggml_type type;
59775977 const std::array<int64_t , 4 > ne;
59785978 float c;
59795979
59805980 std::string vars () override { return VARS_TO_STR3 (type, ne, c); }
59815981
5982- test_const (float c, ggml_type type = GGML_TYPE_F32,
5982+ test_fill (float c, ggml_type type = GGML_TYPE_F32,
59835983 std::array<int64_t , 4 > ne = { 10 , 10 , 4 , 3 })
59845984 : type(type), ne(ne), c(c) {}
59855985
59865986 ggml_tensor * build_graph (ggml_context * ctx) override {
5987- ggml_tensor * out = ggml_const (ctx, ne[0 ], ne[1 ], ne[2 ], ne[3 ], c);
5987+ ggml_tensor * a = ggml_new_tensor_4d (ctx, type, ne[0 ], ne[1 ], ne[2 ], ne[3 ]);
5988+ ggml_set_param (a);
5989+ ggml_set_name (a, " a" );
5990+
5991+ ggml_tensor * out = ggml_fill (ctx, a, c);
59885992
59895993 ggml_set_name (out, " out" );
59905994
@@ -5995,27 +5999,27 @@ struct test_const : public test_case {
59955999// GGML_OP_SOLVE_TRI
59966000struct test_solve_tri : public test_case {
59976001 const ggml_type type;
5998- const std::array<int64_t , 4 > ne ;
5999- const std::array<int64_t , 4 > ne2 ;
6002+ const std::array<int64_t , 4 > neLHS ;
6003+ const std::array<int64_t , 4 > neRHS ;
60006004
6001- std::string vars () override { return VARS_TO_STR3 (type, ne, ne2 ); }
6005+ std::string vars () override { return VARS_TO_STR3 (type, neLHS, neRHS ); }
60026006
60036007 test_solve_tri (ggml_type type = GGML_TYPE_F32,
6004- std::array<int64_t , 4 > ne = { 10 , 10 , 4 , 3 },
6005- std::array<int64_t , 4 > ne2 = { 3 , 10 , 4 , 3 }
6008+ std::array<int64_t , 4 > neLHS = { 10 , 10 , 4 , 3 },
6009+ std::array<int64_t , 4 > neRHS = { 3 , 10 , 4 , 3 }
60066010 )
6007- : type(type), ne(ne ), ne2(ne2 ) {}
6011+ : type(type), neLHS(neLHS ), neRHS(neRHS ) {}
60086012
60096013 ggml_tensor * build_graph (ggml_context * ctx) override {
6010- ggml_tensor * a = ggml_new_tensor_4d (ctx, type, ne [0 ], ne [1 ], ne [2 ], ne [3 ]);
6014+ ggml_tensor * a = ggml_new_tensor_4d (ctx, type, neLHS [0 ], neLHS [1 ], neLHS [2 ], neLHS [3 ]);
60116015 ggml_set_param (a);
60126016 ggml_set_name (a, " a" );
60136017
6014- ggml_tensor * b = ggml_new_tensor_4d (ctx, type, ne2 [0 ], ne2 [1 ], ne2 [2 ], ne2 [3 ]);
6018+ ggml_tensor * b = ggml_new_tensor_4d (ctx, type, neRHS [0 ], neRHS [1 ], neRHS [2 ], neRHS [3 ]);
60156019 ggml_set_param (b);
60166020 ggml_set_name (b, " b" );
60176021
6018- ggml_tensor * out = ggml_solve_tri (ctx, a, b);
6022+ ggml_tensor * out = ggml_solve_tri (ctx, a, b, true , true , false );
60196023 ggml_set_name (out, " out" );
60206024
60216025 return out;
@@ -6024,7 +6028,7 @@ struct test_solve_tri : public test_case {
60246028 void initialize_tensors (ggml_context * ctx) override {
60256029 for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != NULL ; t = ggml_get_next_tensor (ctx, t)) {
60266030 if (strcmp (t->name , " a" ) == 0 ) {
6027- init_tensor_causal (t, 0.1 , 1 .0f );
6031+ init_tensor_tril (t, 0.1 , 1 .0f );
60286032 } else {
60296033 init_tensor_uniform (t, 0.1 , 1 .0f );
60306034 }
@@ -7528,9 +7532,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
75287532 test_cases.emplace_back (new test_tri (GGML_TRI_TYPE_UPPER));
75297533 test_cases.emplace_back (new test_tri (GGML_TRI_TYPE_UPPER_DIAG));
75307534
7531- test_cases.emplace_back (new test_const (0 .0f ));
7532- test_cases.emplace_back (new test_const (2 .0f , GGML_TYPE_F32, { 303 , 207 , 11 , 3 }));
7533- test_cases.emplace_back (new test_const (-152 .0f , GGML_TYPE_F32, { 800 , 600 , 4 , 4 }));
7535+ test_cases.emplace_back (new test_fill (0 .0f ));
7536+ test_cases.emplace_back (new test_fill (2 .0f , GGML_TYPE_F32, { 303 , 207 , 11 , 3 }));
7537+ test_cases.emplace_back (new test_fill (-152 .0f , GGML_TYPE_F32, { 800 , 600 , 4 , 4 }));
75347538
75357539 test_cases.emplace_back (new test_solve_tri ());
75367540 test_cases.emplace_back (new test_solve_tri (GGML_TYPE_F32, { 11 , 11 , 1 , 1 }, { 5 , 11 , 1 , 1 }));
0 commit comments