diff --git a/.github/workflows/c-cpp.yml b/.github/workflows/c-cpp.yml index 759dbda98..296011e78 100644 --- a/.github/workflows/c-cpp.yml +++ b/.github/workflows/c-cpp.yml @@ -4,7 +4,7 @@ on: push: branches: [ master ] pull_request: - branches: [ master, outside_looporder, first_derivative ] + branches: [ master, outside_looporder, first_derivative, second_derivative ] workflow_dispatch: inputs: logLevel: diff --git a/.github/workflows/outside.yml b/.github/workflows/outside.yml index 168daa139..adf3e4c75 100644 --- a/.github/workflows/outside.yml +++ b/.github/workflows/outside.yml @@ -4,7 +4,7 @@ on: push: branches: [ master ] pull_request: - branches: [ master, outside_looporder ] + branches: [ master, outside_looporder, first_derivative, second_derivative ] workflow_dispatch: inputs: logLevel: diff --git a/rtlib/generic_main.cc b/rtlib/generic_main.cc index 0d33744a9..d428575d4 100644 --- a/rtlib/generic_main.cc +++ b/rtlib/generic_main.cc @@ -46,6 +46,9 @@ int main(int argc, char **argv) { std::exit(1); } gapc::class_name obj; +#ifdef SECOND_DERIVATIVE + gapc::class_name_D2 obj_D2; +#endif try { obj.init(opts); @@ -116,6 +119,15 @@ int main(int argc, char **argv) { gapc::add_event("end"); #endif +#ifdef SECOND_DERIVATIVE + obj_D2.init(opts, &obj); + gapc::add_event("start second derivative"); + obj_D2.run(); + gapc::add_event("end_computation of second derivative"); + obj_D2.report_derivative(std::cout); + gapc::add_event("end_result of second derivative"); +#endif + #ifdef STATS obj.print_stats(std::cerr); #endif diff --git a/rtlib/traces.hh b/rtlib/traces.hh index 5f949c6e7..84f5ba438 100644 --- a/rtlib/traces.hh +++ b/rtlib/traces.hh @@ -79,6 +79,7 @@ bool is_same_index(std::vector a, std::vector b) { class candidate { private: double value; + double q; std::vector sub_components; public: @@ -94,6 +95,10 @@ class candidate { this->value = value; } + void set_q(double q) { + this->q = q; + } + void add_sub_component(std::string otherNT, std::vector *indices) { sub_components.push_back({otherNT, *indices, NULL}); @@ -107,6 +112,10 @@ class candidate { return value; } + double get_q() const { + return q; + } + /* we need to normalize the individual trace values into probabilities * trace values can come in different flavors, depending on what the * user uses as scoring schema. See Algebra::check_derivative @@ -121,6 +130,16 @@ class candidate { } return res; } + + std::vector get_soft_max_hessian_candidate(double eval) const { + std::vector res; + for (std::vector::const_iterator part = this->sub_components.begin(); + part != this->sub_components.end(); ++part) { + res.push_back({std::get<0>(*part), std::get<1>(*part), + this->get_q() - (eval * this->get_value())}); + } + return res; + } // void empty() { // empty(this->value); // } @@ -143,6 +162,19 @@ std::vector normalize_traces(std::vector *tabulated, return res; } +inline +std::vector soft_max_hessian_product(std::vector *tabulated, + const std::vector &candidates, + double eval) { + std::vector, double > > res; + for (std::vector::const_iterator i = candidates.begin(); + i != candidates.end(); ++i) { + std::vector comp = (*i).get_soft_max_hessian_candidate(eval); + res.insert(res.end(), comp.begin(), comp.end()); + } + return res; +} + inline double get_trace_weights(const std::vector &traces, const std::string &to_nt, diff --git a/src/alt.cc b/src/alt.cc index 7a0c563c4..18c7fb7e4 100644 --- a/src/alt.cc +++ b/src/alt.cc @@ -47,6 +47,7 @@ #include "instance.hh" #include "statement/fn_call.hh" +#include "statement/table_decl.hh" Alt::Base::Base(Type t, const Loc &l) : @@ -56,7 +57,7 @@ Alt::Base::Base(Type t, const Loc &l) : productive(false), datatype(NULL), eliminated(false), terminal_type(false), location(l), - ret_decl(NULL), filter_guards(NULL), + ret_decl(NULL), edgeweight_decl(NULL), filter_guards(NULL), choice_fn_type_(Expr::Fn_Call::NONE), tracks_(0), track_pos_(0), is_partof_outside(false) { } @@ -1277,6 +1278,8 @@ void Alt::Base::init_ret_decl(unsigned int i, const std::string &prefix) { std::ostringstream o; o << prefix << "ret_" << i; ret_decl = new Statement::Var_Decl(datatype, new std::string(o.str())); + edgeweight_decl = new Statement::Var_Decl(datatype, + new std::string("edgeweight_" + *ret_decl->name)); } @@ -1424,7 +1427,7 @@ void Alt::Base::add_seqs(Expr::Fn_Call *fn_call, const AST &ast) const { } // TODO(sjanssen): rename outside_fn_arg to reflect actual datatype -Expr::Fn_Call *Alt::Base::inject_derivative_body(AST &ast, +Expr::Base *Alt::Base::inject_derivative_body(AST &ast, Symbol::NT &calling_nt, Alt::Base *outside_fn_arg, Expr::Base *outside_arg) { assert(outside_fn_arg); @@ -1463,12 +1466,118 @@ Expr::Fn_Call *Alt::Base::inject_derivative_body(AST &ast, mkidx->add_arg(new Expr::Const(static_cast(mkidx->exprs.size())), true); fn_call->exprs.push_back(mkidx); - // result of outside non terminal, e.g. a_0 - fn_call->exprs.push_back(outside_arg); + if (ast.current_derivative == 1) { + // result of outside non terminal, e.g. a_0 + fn_call->exprs.push_back(outside_arg); + } else if (ast.current_derivative == 2) { + // results of adjoint backward + Expr::Fn_Call *fn_e2 = dynamic_cast( + outside_fn_arg->ret_decl->rhs); + /* if rhs is a direct link to another NT, outside_fn_arg->ret_decl->rhs + * is empty and we need to resort to outside_arg */ + if (!fn_e2) { + fn_e2 = dynamic_cast(outside_arg); + } + + // results of backward + Expr::Fn_Call *fn_e1 = new Expr::Fn_Call(new std::string( + "derivative" + std::to_string(ast.current_derivative-1) + "->" + + *(fn_e2->name))); + fn_e1->add(fn_e2->exprs); + + Expr::Fn_Call *fn_q1 = new Expr::Fn_Call(new std::string(*(fn_call->name))); + fn_q1->add_arg(new std::string( + "derivative" + std::to_string(ast.current_derivative-1) + "->" + + (*outside_alt->name).substr( + sizeof(OUTSIDE_NT_PREFIX)-1, + (*outside_alt->name).length()) + "_table")); + fn_q1->is_obj = Bool(true); + fn_q1->exprs.insert( + fn_q1->exprs.end(), + std::next(fn_call->exprs.begin()), + fn_call->exprs.end()); + // don't access ret_X if it gets defined in the very same statement + if (outside_arg->is(Expr::VACC)) { + fn_q1->add_arg(outside_fn_arg->ret_decl->name); + } else { + fn_q1->add_arg(new Expr::Const(0.0)); + } + + fn_call->exprs.push_back(fn_e1); + return new Expr::Plus(fn_call, fn_q1); + } return fn_call; } + +/* iterates through the arguments of an alternative and adds one statement per + NT called to obtain derivative edge weight */ +std::list *Alt::Simple::derivative_collect_traces( + AST &ast, Symbol::NT &calling_nt) { + std::list *stmts = new std::list(); + + for (std::list::iterator i = args.begin(); + i != args.end(); ++i) { + if ((*i)->is(Fn_Arg::CONST)) { + continue; + } + Fn_Arg::Alt *fn_alt = dynamic_cast(*i); + if (fn_alt) { + Alt::Link *alt_link = dynamic_cast((*fn_alt).alt_ref()); + if (alt_link) { + Symbol::NT *alt_nt = dynamic_cast(alt_link->nt); + if (alt_nt) { + Expr::Fn_Call *fn_call = new Expr::Fn_Call( + new std::string("get_traces")); + + // access lower derivative table + fn_call->add_arg(new std::string( + "derivative" + + std::to_string(ast.current_derivative-1) + + "->" + calling_nt.table_decl->name())); + fn_call->is_obj = Bool(true); + + // index of the calling non-terminal + Fn_Def *x = new Fn_Def(); + x->add_para(calling_nt); + for (std::list::const_iterator i = x->paras.begin(); + i != x->paras.end(); ++i) { + Para_Decl::Simple *s = dynamic_cast(*i); + if (s) { + fn_call->add_arg((*s).name()); + } + } + + // add name of non-terminal that requests traces + // together with make_index, this information is used to sub-set + // stored traces to those that actually lead to this DP cell + fn_call->exprs.push_back(new Expr::Const(*alt_nt->name)); + + // calling make_index + Expr::Fn_Call *mkidx = new Expr::Fn_Call( + new std::string("*make_index")); + alt_link->add_args(mkidx); + mkidx->add_arg(new Expr::Const(static_cast(mkidx->exprs.size())), + true); + fn_call->exprs.push_back(mkidx); + + fn_call->exprs.push_back(new Expr::Const(1.0)); + + Statement::Var_Decl *q = new Statement::Var_Decl( + decl->return_type, + this->edgeweight_decl->name); + q->rhs = fn_call; + Statement::Var_Assign *stmt_ass = new Statement::Var_Assign( + *q, new Expr::Times(new Expr::Vacc(*q), fn_call)); + stmts->push_back(stmt_ass); + } + } + } + } + return stmts; +} + void Alt::Simple::init_body(AST &ast, Symbol::NT &calling_nt) { body_stmts.clear(); @@ -1504,7 +1613,7 @@ void Alt::Simple::init_body(AST &ast, Symbol::NT &calling_nt) { fn_call->exprs.push_back(c->ret_decls().front()->rhs); continue; } - if (this->get_is_partof_outside() && ast.inject_derivatives) { + if (this->get_is_partof_outside() && (ast.current_derivative > 0)) { Fn_Arg::Alt *fn_alt = dynamic_cast(*i); if (fn_alt) { Alt::Link *alt_link = dynamic_cast((*fn_alt).alt_ref()); @@ -1537,11 +1646,34 @@ void Alt::Simple::init_body(AST &ast, Symbol::NT &calling_nt) { decl->return_type, new std::string("ans")); pre_decl.clear(); pre_decl.push_back(vdecl); - if (ast.inject_derivatives && this->get_is_partof_outside() && + if ((ast.current_derivative >= 1) && this->get_is_partof_outside() && outside_fn_arg) { vdecl->rhs = inject_derivative_body(ast, calling_nt, outside_fn_arg->alt_ref(), outside_arg); + } else if ((ast.current_derivative == 2) && + !this->get_is_partof_outside() && + !outside_fn_arg) { + // obtain edge weight q + std::list *stmts_qs = derivative_collect_traces( + ast, calling_nt); + stmts->insert(stmts->end(), stmts_qs->begin(), stmts_qs->end()); + + // multiply combined q with nt result + vdecl->rhs = new Expr::Times(new Expr::Vacc(this->edgeweight_decl->name), + fn_call); +// if (ast.current_derivative == 2) { +// Expr::Fn_Call *e = new Expr::Fn_Call(Expr::Fn_Call::NOT_EMPTY); +// e->add_arg(vdecl->name); +// Statement::If *cond = new Statement::If(e); +// cond->then.push_back(vdecl); +// Statement::Fn_Call *erase = +// new Statement::Fn_Call(Statement::Fn_Call::ERASE); +// erase->add_arg(vdecl->name); +// cond->els.push_back(erase); +// +// //vdecl = cond; +// } } else { vdecl->rhs = fn_call; } @@ -1551,43 +1683,91 @@ void Alt::Simple::init_body(AST &ast, Symbol::NT &calling_nt) { fn->add_arg(*vdecl); stmts->push_back(vdecl); init_derivative_recording(ast, vdecl->name); + std::list *stmts_cmp_push = + new std::list(); Expr::Base *suchthat = suchthat_code(*vdecl); if (suchthat) { Statement::If *c = new Statement::If(suchthat); c->then.push_back(fn); c->then.insert(c->then.end(), this->derivative_statements.begin(), this->derivative_statements.end()); - stmts->push_back(c); + stmts_cmp_push->push_back(c); } else { - stmts->push_back(fn); - stmts->insert(stmts->end(), this->derivative_statements.begin(), - this->derivative_statements.end()); + stmts_cmp_push->push_back(fn); + stmts_cmp_push->insert(stmts_cmp_push->end(), + this->derivative_statements.begin(), + this->derivative_statements.end()); } + if ((ast.current_derivative == 2) && !this->get_is_partof_outside()) { + Expr::Fn_Call *e = new Expr::Fn_Call(Expr::Fn_Call::NOT_EMPTY); + e->add_arg(vdecl->name); + Statement::If *cond_edge_empty = new Statement::If(e); + cond_edge_empty->then.insert(cond_edge_empty->then.begin(), + stmts_cmp_push->begin(), + stmts_cmp_push->end()); + stmts_cmp_push->clear(); + stmts_cmp_push->push_back(cond_edge_empty); + } + stmts->insert(stmts->end(), + stmts_cmp_push->begin(), stmts_cmp_push->end()); // clear this list, as it has just been added to the statements this->derivative_statements.clear(); } else { Statement::Var_Assign *ass = new Statement::Var_Assign(*ret_decl); pre_decl.clear(); pre_decl.push_back(ret_decl); - if (ast.inject_derivatives && calling_nt.is_partof_outside && + if ((ast.current_derivative >= 1) && calling_nt.is_partof_outside && outside_fn_arg) { ass->rhs = inject_derivative_body(ast, calling_nt, outside_fn_arg->alt_ref(), outside_arg); + } else if ((ast.current_derivative == 2) && + !calling_nt.is_partof_outside && + !outside_fn_arg && !is_terminal()) { + // obtain edge weight q + std::list *stmts_qs = derivative_collect_traces( + ast, calling_nt); + stmts->insert(stmts->end(), stmts_qs->begin(), stmts_qs->end()); + // multiply combined q with nt result + ass->rhs = new Expr::Times(new Expr::Vacc(this->edgeweight_decl->name), + fn_call); } else { ass->rhs = fn_call; } // derviative statements will later be added in Symbol::NT::codegen init_derivative_recording(ast, ret_decl->name); - stmts->push_back(ass); + // helper list of statements, to allow wrapping with IF condition + // in case of second derivative generation to check if edge weight + // is empty + std::list *stmts_cmp_push = + new std::list(); + stmts_cmp_push->push_back(ass); Expr::Base *suchthat = suchthat_code(*ret_decl); if (suchthat) { Statement::If *c = new Statement::If(suchthat); - Statement::Fn_Call *e = new Statement::Fn_Call(Statement::Fn_Call::EMPTY); + Statement::Fn_Call *e = new Statement::Fn_Call( + Statement::Fn_Call::EMPTY); e->add_arg(*ret_decl); c->els.push_back(e); - stmts->push_back(c); + stmts_cmp_push->push_back(c); } + if ((ast.current_derivative == 2) && !this->get_is_partof_outside() && + !is_terminal()) { + Expr::Fn_Call *e = new Expr::Fn_Call(Expr::Fn_Call::NOT_EMPTY); + e->add_arg(this->edgeweight_decl->name); + Statement::If *cond_edge_empty = new Statement::If(e); + cond_edge_empty->then.insert(cond_edge_empty->then.begin(), + stmts_cmp_push->begin(), + stmts_cmp_push->end()); + Statement::Fn_Call *erase = new Statement::Fn_Call( + Statement::Fn_Call::EMPTY); + erase->add_arg(ret_decl->name); + cond_edge_empty->els.push_back(erase); + stmts_cmp_push->clear(); + stmts_cmp_push->push_back(cond_edge_empty); + } + stmts->insert(stmts->end(), + stmts_cmp_push->begin(), stmts_cmp_push->end()); } } @@ -1656,7 +1836,7 @@ std::list *Alt::Multi::derivatives_create_candidate() { void Alt::Base::init_derivative_recording( AST &ast, std::string *result_name) { - if (ast.inject_derivatives) { + if (ast.current_derivative > 0) { if (!this->is_partof_outside) { // test if this alternative uses sub-solutions from other non-terminals std::list *stmts_record = @@ -1667,7 +1847,17 @@ void Alt::Base::init_derivative_recording( x->add_arg(new std::string("cand")); x->is_obj = Bool(true); assert(result_name); - x->add_arg(result_name); + if (ast.current_derivative == 1) { + x->add_arg(result_name); + } else { + x->add_arg(this->edgeweight_decl->name); + + Statement::Fn_Call *y = new Statement::Fn_Call("set_q"); + y->add_arg(new std::string("cand")); + y->is_obj = Bool(true); + y->add_arg(result_name); + stmts_record->push_front(y); + } stmts_record->push_front(x); Statement::Var_Decl *candidate = new Statement::Var_Decl( @@ -1794,8 +1984,14 @@ void Alt::Simple::init_guards() { } -void Alt::Base::push_back_ret_decl() { +void Alt::Base::push_back_ret_decl(unsigned int current_derivative, + bool outside_generation) { statements.push_back(ret_decl); + if (top_level && current_derivative > 1 + && !is_partof_outside && !outside_generation) { + this->edgeweight_decl->rhs = new Expr::Const(1.0); + statements.push_back(this->edgeweight_decl); + } } @@ -1861,6 +2057,9 @@ std::list *Alt::Simple::reorder_args_cg( Symbol::NT &calling_nt) { for (std::list::iterator i = args.begin(); i != args.end(); ++i) { + if (this->top_level && (*i)->is(Fn_Arg::ALT)) { + (*i)->alt_ref()->edgeweight_decl = this->edgeweight_decl; + } (*i)->codegen(ast, calling_nt); // TODO(sjanssen): is this really the best way for nested Blocks?? if ((*i)->is(Fn_Arg::ALT) && (*i)->alt_ref()->is(Alt::BLOCK) @@ -2161,7 +2360,7 @@ void Alt::Simple::codegen(AST &ast, Symbol::NT &calling_nt) { } statements.clear(); - push_back_ret_decl(); + push_back_ret_decl(ast.current_derivative, calling_nt.is_partof_outside); std::list *stmts = &statements; init_guards(); @@ -2355,7 +2554,7 @@ void Alt::Link::codegen(AST &ast, Symbol::NT &calling_nt) { // std::cout << "link " << *name << std::endl; statements.clear(); - push_back_ret_decl(); + push_back_ret_decl(ast.current_derivative, calling_nt.is_partof_outside); std::string *s = NULL; if (nt->is(Symbol::TERMINAL)) { s = name; @@ -2398,9 +2597,12 @@ void Alt::Link::codegen(AST &ast, Symbol::NT &calling_nt) { /* In case of first derivatives, we want to normalize all result to * probabilities. Therefore, the recursion base of the outside pass must be * set to 1.0, instead of using the result of the initial inside pass, - * e.g. axiom with complete input substring */ + * e.g. axiom with complete input substring. + * It's similar for second derivatives, but since they are no probabilities + * recursion base must be 0.0 not 1.0. + * */ Expr::Base *fn_or_const = fn; - if ((ast.inject_derivatives) && + if ((ast.current_derivative >= 1) && (*calling_nt.orig_name == *ast.grammar()->axiom_name_inside)) { unsigned int lacking_complete_tracks = calling_nt.tracks(); for (std::vector >::const_iterator track = \ @@ -2413,7 +2615,11 @@ void Alt::Link::codegen(AST &ast, Symbol::NT &calling_nt) { } } if (lacking_complete_tracks == 0) { - fn_or_const = new Expr::Const(1.0); + if (ast.current_derivative == 1) { + fn_or_const = new Expr::Const(1.0); + } else if (ast.current_derivative == 2) { + fn_or_const = new Expr::Const(0.0); + } } } @@ -2424,7 +2630,8 @@ void Alt::Link::codegen(AST &ast, Symbol::NT &calling_nt) { statements.push_back(filter_guards); filter_guards->then.push_back(v); } else { - if (nt->is(Symbol::NONTERMINAL) && this->top_level && ast.inject_derivatives + if (nt->is(Symbol::NONTERMINAL) && this->top_level + && (ast.current_derivative > 0) && calling_nt.is_partof_outside) { ret_decl->rhs = inject_derivative_body(ast, calling_nt, this, fn_or_const); @@ -2456,7 +2663,7 @@ void Alt::Link::codegen(AST &ast, Symbol::NT &calling_nt) { void Alt::Block::codegen(AST &ast, Symbol::NT &calling_nt) { // std::cout << "-----------------Block " << std::endl; statements.clear(); - push_back_ret_decl(); + push_back_ret_decl(ast.current_derivative, calling_nt.is_partof_outside); Statement::Fn_Call *fn = new Statement::Fn_Call(Statement::Fn_Call::EMPTY); fn->add_arg(*ret_decl); statements.push_back(fn); diff --git a/src/alt.hh b/src/alt.hh index 4eb826e34..7b65838c8 100644 --- a/src/alt.hh +++ b/src/alt.hh @@ -144,6 +144,8 @@ class Base { public: Statement::Var_Decl *ret_decl; + // for derivative computation: stores the edge weight of lower derivatives: q + Statement::Var_Decl *edgeweight_decl; inline bool is(Type t) { return type == t; @@ -245,7 +247,8 @@ class Base { protected: Statement::If *filter_guards; - void push_back_ret_decl(); + void push_back_ret_decl(unsigned int current_derivative, + bool outside_generation); Expr::Base *suchthat_code(Statement::Var_Decl &decl) const; @@ -376,7 +379,7 @@ class Base { bool inside_end = false; void init_derivative_recording(AST &ast, std::string *result_name); - Expr::Fn_Call *inject_derivative_body(AST &ast, Symbol::NT &calling_nt, + Expr::Base *inject_derivative_body(AST &ast, Symbol::NT &calling_nt, Alt::Base *outside_fn_arg, Expr::Base *outside_arg); public: @@ -596,6 +599,10 @@ class Simple : public Base { Alt::Base* find_block(); Alt::Base *find_block_parent(const Alt::Base &block); + // generate code to obtain edge weights (q) for each rhs non-terminal + std::list *derivative_collect_traces( + AST &ast, Symbol::NT &calling_nt); + private: std::list *insert_index_stmts( std::list *stmts); diff --git a/src/ast.cc b/src/ast.cc index 5e7bdad4e..f2f3bc6e6 100644 --- a/src/ast.cc +++ b/src/ast.cc @@ -774,6 +774,43 @@ std::pair AST::split_classified(const std::string &n) { return std::make_pair(score, i); } +std::pair AST::split_instance_for_derivatives( + const std::string &n) { + Instance *i = instance(n); + if (!i) { + throw LogError("Instance does not exist."); + } + + if (!i->product->is(Product::TIMES)) { + throw LogError("Algebra product is not of type times, i.e. '*'!"); + } + if (!i->product->left()->is(Product::SINGLE)) { + throw LogError( + "Left algebra is no single algebra, but an algebra product!"); + } + if (!i->product->right()->is(Product::SINGLE)) { + throw LogError( + "Right algebra is no single algebra, but an algebra product!"); + } + + Instance *inst_firstD = new Instance( + new std::string("first derivative"), i->product->left(), grammar()); + if (!inst_firstD->product->algebra()->is_compatible(Mode::SYNOPTIC)) { + throw LogError( + "Left algebra is not synoptic, e.g. choice function is not sum."); + } + check_instances(inst_firstD); + Instance *inst_secondD = new Instance( + new std::string("second derivative"), i->product->right(), grammar()); + if (!inst_secondD->product->algebra()->is_compatible(Mode::SYNOPTIC)) { + throw LogError( + "Right algebra is not synoptic, e.g. choice function is not sum."); + } + check_instances(inst_secondD); + + return std::make_pair(inst_firstD, inst_secondD); +} + #include "unused_visitor.hh" diff --git a/src/ast.hh b/src/ast.hh index d0a8ca230..70e4ab44a 100644 --- a/src/ast.hh +++ b/src/ast.hh @@ -261,6 +261,8 @@ class AST { public: Instance *split_instance_for_backtrack(std::string &n); std::pair split_classified(const std::string &n); + std::pair split_instance_for_derivatives( + const std::string &n); void backtrack_gen(Backtrack_Base &bt); void warn_unused_fns(Instance &i); @@ -295,7 +297,18 @@ class AST { return backtrack_product; } - bool inject_derivatives = false; + // tracks which derivative is currently being generated (first or second) + unsigned int current_derivative = 0; + + /* the derivative requested by the user. + * 0 = default, i.e. no derivative generation + * 1 = generate code to compute first derivative (Jacobians), + * e.g. base pair probabilities, forward-backward, ... + * 2 = generate code to compute second derivative (Hessians), + * e.g. for machine learning, which also required generation + * of first derivatives + */ + unsigned int requested_derivative = 0; }; #endif // SRC_AST_HH_ diff --git a/src/codegen.cc b/src/codegen.cc index c1128a0cc..4604fadbd 100644 --- a/src/codegen.cc +++ b/src/codegen.cc @@ -27,7 +27,7 @@ #include "fn_def.hh" Code::Gen::Gen(AST &ast) { - if (ast.inject_derivatives) { + if (ast.current_derivative > 0) { Symbol::NT *inside_axiom = dynamic_cast( ast.grammar()->NTs[*ast.grammar()->axiom_name_inside]); assert(inside_axiom); diff --git a/src/cpp.cc b/src/cpp.cc index 44f5380cd..ceab5cda1 100644 --- a/src/cpp.cc +++ b/src/cpp.cc @@ -1606,7 +1606,12 @@ void Printer::Cpp::print_most_init(const AST &ast) { void Printer::Cpp::print_init_fn(const AST &ast) { stream << indent() << "void init("; - stream << "const gapc::Opts &opts)" << " {" << endl; + stream << "const gapc::Opts &opts"; + for (unsigned int i = 1; i < ast.current_derivative; ++i) { + stream << ", " << get_class_name_lower_derivative(ast.current_derivative, i) + << " *derivative" << std::to_string(i); + } + stream << ") {" << endl; inc_indent(); stream << indent() << "const std::vector >" @@ -1628,6 +1633,14 @@ void Printer::Cpp::print_init_fn(const AST &ast) { } } + if (ast.requested_derivative > 0) { + stream << endl; + } + for (unsigned int i = 1; i < ast.current_derivative; ++i) { + stream << indent() << "this->derivative" << std::to_string(i) + << " = derivative" << std::to_string(i) << ";" << endl; + } + dec_indent(); stream << indent() << '}' << endl << endl; } @@ -1718,8 +1731,10 @@ void Printer::Cpp::header(const AST &ast) { } if ((*ast.grammar()).is_outside()) { stream << "#define OUTSIDE\n"; - if (ast.inject_derivatives) { + if (ast.current_derivative == 1) { stream << "#define DERIVATIVES\n"; + } else if (ast.current_derivative == 2) { + stream << "#define SECOND_DERIVATIVE\n"; } } @@ -1733,6 +1748,7 @@ void Printer::Cpp::header(const AST &ast) { } imports(ast); + print_hash_decls(ast); stream << indent() << "class " << class_name << " {" << endl; @@ -1751,6 +1767,13 @@ void Printer::Cpp::header(const AST &ast) { stream << indent() << "unsigned winc;" << endl; } + /* create pointer to lower derivative results */ + for (unsigned int i = 1; i < ast.current_derivative; ++i) { + stream << indent() + << get_class_name_lower_derivative(ast.current_derivative, i) + << " *derivative" << i << ";" << endl; + } + stream << endl; print_zero_decls(*ast.grammar()); @@ -1761,7 +1784,16 @@ void Printer::Cpp::header(const AST &ast) { print_init_fn(ast); print_window_inc_fn(ast); dec_indent(); - stream << indent() << " private:" << endl; + stream << indent(); + if ((ast.current_derivative > 0) && + (ast.current_derivative < ast.requested_derivative)) { + // let higher derivatives access lower DP results, e.g. second needs first + // however, last derivative can stay private + stream << " public:"; + } else { + stream << " private:"; + } + stream << endl; inc_indent(); } @@ -2347,7 +2379,8 @@ void Printer::Cpp::print_insideoutside_report_fn( } void Printer::Cpp::print_derivative(Symbol::NT *nt) { - stream << indent() << "std::cout << \"1. derivatives for non-terminal \\\"" + stream << indent() << "std::cout << \"" << ast->current_derivative + << ". derivatives for non-terminal \\\"" << (*nt->name).substr(sizeof(OUTSIDE_NT_PREFIX)-1, (*nt->name).length()) << "\\\":\\n\";" << endl; @@ -2452,7 +2485,7 @@ void Printer::Cpp::print_run_derivative_fn(const AST &ast) { void Printer::Cpp::print_run_fn(const AST &ast) { Symbol::NT *axiom = ast.grammar()->axiom; - if (ast.inject_derivatives) { + if (ast.current_derivative > 0) { axiom = dynamic_cast( ast.grammar()->NTs[*ast.grammar()->axiom_name_inside]); } @@ -2828,13 +2861,19 @@ void Printer::Cpp::close_class() { } -void Printer::Cpp::typedefs(Code::Gen &code) { +void Printer::Cpp::typedefs(Code::Gen &code, unsigned int current_derivative) { stream << "#ifndef NO_GAPC_TYPEDEFS" << endl; stream << indent() << "namespace gapc {" << endl; inc_indent(); - stream << indent() << "typedef " << class_name << " class_name;" << endl; - stream << indent() << "typedef " << *code.return_type() - << " return_type;" << endl; + stream << indent() << "typedef " << class_name << " class_name"; + if (current_derivative >= 2) { + stream << "_D2"; + } + stream << ";" << endl; + if (current_derivative <= 1) { + stream << indent() << "typedef " << *code.return_type() + << " return_type;" << endl; + } dec_indent(); stream << indent() << '}' << endl; stream << "#endif" << endl; @@ -2879,7 +2918,7 @@ static const char deps[] = #include "prefix.hh" -void Printer::Cpp::makefile(const Options &opts) { +void Printer::Cpp::makefile(const Options &opts, const AST &ast) { stream << endl << make_comments(id_string, "#") << endl << endl; // stream << "SED = sed\n"; @@ -2902,7 +2941,15 @@ void Printer::Cpp::makefile(const Options &opts) { << "endif" << endl << endl; std::string base = opts.class_name; // basename(opts.out_file); - std::string out_file = remove_dir(opts.out_file); + std::string out_file = ""; + if (ast.requested_derivative > 1) { + for (unsigned int i = 1; i <= ast.requested_derivative; ++i) { + out_file += basename(remove_dir(opts.out_file)) + \ + "_derivative" + std::to_string(i) + ".cc "; + } + } else { + out_file = remove_dir(opts.out_file); + } std::string header_file = remove_dir(opts.header_file); stream << "CXXFILES = " << base << "_main.cc " << out_file << endl << endl; @@ -2920,9 +2967,22 @@ void Printer::Cpp::makefile(const Options &opts) { } stream << endl << endl - << base << "_main.cc : $(RTLIB)/generic_main.cc " << out_file << endl - << "\techo '#include \"" << header_file << "\"' > $@" << endl - << "\tcat $(RTLIB)/generic_main.cc >> " << base << "_main.cc" << endl + << base << "_main.cc : $(RTLIB)/generic_main.cc " << out_file << endl; + if (ast.requested_derivative > 1) { + for (unsigned int i = 1; i <= ast.requested_derivative; ++i) { + stream << "\techo '#include \"" << basename(remove_dir(opts.out_file)) + << "_derivative" << std::to_string(i) << ".hh\"' >"; + if (i > 1) { + stream << ">"; + } + stream << " $@" << endl; + } + } else { + stream << "\techo '#include " << "\"" << header_file << "\"" + << "' > $@" << endl; + } + + stream << "\tcat $(RTLIB)/generic_main.cc >> " << base << "_main.cc" << endl << endl; stream << deps << endl; stream << ".PHONY: clean" << endl << "clean:" << endl @@ -2940,7 +3000,7 @@ void Printer::Cpp::imports(const AST &ast) { return; } - if (ast.inject_derivatives) { + if (ast.current_derivative > 0) { stream << "#include \"rtlib/traces.hh\"" << endl; } @@ -2993,9 +3053,18 @@ void Printer::Cpp::imports(const AST &ast) { } } stream << endl; - stream << "#include \"rtlib/generic_opts.hh\"\n"; - stream << "#include \"rtlib/pareto_dom_sort.hh\"\n"; - stream << "#include \"rtlib/pareto_yukish_ref.hh\"\n\n"; + stream << "#include \"rtlib/generic_opts.hh\"" << endl; + stream << "#include \"rtlib/pareto_dom_sort.hh\"" << endl; + stream << "#include \"rtlib/pareto_yukish_ref.hh\"" << endl; + + /* include code of lower derivatives */ + for (unsigned int i = 1; i < ast.current_derivative; ++i) { + stream << indent() << "#include \"" + << get_class_name_lower_derivative(ast.current_derivative, i) + << ".hh\"" << endl; + } + + stream << endl; } diff --git a/src/cpp.hh b/src/cpp.hh index 67f90708e..b585f88e3 100644 --- a/src/cpp.hh +++ b/src/cpp.hh @@ -158,9 +158,25 @@ class Cpp : public Base { void print_marker_init(const AST &ast); void print_marker_clear(const AST &ast); + std::string class_name; + public: + void set_class_name(std::string class_name, + unsigned int current_derivative = 0) { + this->class_name = class_name; + if (current_derivative > 0) { + this->class_name = this->class_name + "_derivative" + \ + std::to_string(current_derivative); + } + } + std::string get_class_name_lower_derivative( + unsigned int current_derivative, unsigned int derivative) { + assert(current_derivative > 0); + assert(derivative < 10); + return class_name.substr(0, class_name.size()-1) + \ + std::to_string(derivative); + } bool in_class; - std::string class_name; Cpp() : Base(), ast(0), pure_list_type(false), in_fn_head(false), pointer_as_itr(false), @@ -244,14 +260,14 @@ class Cpp : public Base { void header_footer(const AST &ast); void footer(const AST &ast); void close_class(); - void typedefs(Code::Gen &code); + void typedefs(Code::Gen &code, unsigned int current_derivative); void prelude(const Options &opts, const AST &ast); void imports(const AST &ast); void global_constants(const AST &ast); - void makefile(const Options &opts); + void makefile(const Options &opts, const AST &ast); private: bool print_axiom_args(const AST &ast); diff --git a/src/gapc.cc b/src/gapc.cc index 3eebe958a..f36487174 100644 --- a/src/gapc.cc +++ b/src/gapc.cc @@ -368,7 +368,8 @@ class Main { // lets the AST know if code for derivative computation has // to be injected - driver.ast.inject_derivatives = opts.derivative > 0; + driver.ast.requested_derivative = opts.derivative; + driver.ast.current_derivative = opts.derivative > 0 ? 1 : 0; if (driver.is_failing()) { throw LogError("Seen parse errors."); @@ -549,8 +550,8 @@ class Main { } } - if (opts.derivative > 0) { - // if user requests derivative computation, check that user also + if (driver.ast.current_derivative == 1) { + // if user requests first derivative computation, check that user also // provided a normalization function for forward computation instance->product->algebra()->check_derivative(); } @@ -632,7 +633,8 @@ class Main { // dot-file for the grammar. This is handy if gapc modifies the original // grammar from the source file. // activate with command line argument --plot-grammar - if (opts.plot_grammar > 0) { + // (if derivatives are requested: only plot in first iteration) + if (opts.plot_grammar > 0 && driver.ast.current_derivative <= 1) { unsigned int nodeID = 1; grammar->to_dot(&nodeID, opts.plotgrammar_stream(), opts.plot_grammar); Log::instance()->normalMessage( @@ -653,16 +655,19 @@ class Main { // also writes some lines to the header file. Printer::Cpp hh(driver.ast, opts.h_stream()); hh.set_argv(argv, argc); - hh.class_name = opts.class_name; + hh.set_class_name(opts.class_name, driver.ast.current_derivative); hh.header(driver.ast); hh.begin_fwd_decls(); driver.ast.print_code(hh); + /* TODO(sjanssen): is there a nice way to add this code generating + statement only to the header file? */ + opts.h_stream() << " private:"; instance->print_code(hh); hh.footer(driver.ast); hh.end_fwd_decls(); hh.header_footer(driver.ast); if (grammar->is_outside()) { - if (driver.ast.inject_derivatives) { + if (driver.ast.current_derivative > 0) { hh.print_run_derivative_fn(driver.ast); } else { hh.print_insideoutside_report_fn(opts.outside_nt_list, driver.ast); @@ -673,7 +678,7 @@ class Main { // compile-result. Printer::Cpp cc(driver.ast, opts.stream()); cc.set_argv(argv, argc); - cc.class_name = opts.class_name; + cc.set_class_name(opts.class_name, driver.ast.current_derivative); cc.set_files(opts.in_file, opts.out_file); cc.prelude(opts, driver.ast); cc.imports(driver.ast); @@ -719,8 +724,8 @@ class Main { */ void finish() { Printer::Cpp hh(driver.ast, opts.h_stream()); - hh.class_name = opts.class_name; - hh.typedefs(code_); + hh.set_class_name(opts.class_name, driver.ast.current_derivative); + hh.typedefs(code_, driver.ast.current_derivative); } @@ -729,10 +734,10 @@ class Main { * to compile the result of this gapc compiler. * Precondition: the AST must have been created and configured. */ - void makefile() { + void makefile(const AST &ast) { Printer::Cpp mm(driver.ast, opts.m_stream()); mm.set_argv(argv, argc); - mm.makefile(opts); + mm.makefile(opts, ast); } public: @@ -749,7 +754,7 @@ class Main { * This is the entry point where the software starts. */ void runKernal() { - makefile(); + makefile(driver.ast); conv_classified_product(&opts); @@ -775,6 +780,48 @@ class Main { driver.ast.set_code_mode(mode); back(r.second, r.first); + } else if (opts.derivative > 1) { + // split algebra product "left * right" into two instances for first + // and second derivative + std::pair bothD = + driver.ast.split_instance_for_derivatives(opts.instance); + + // store user provided file name pattern for .hh and .cc + std::string orig_header_file = opts.header_file; + std::string orig_out_file = opts.out_file; + + driver.ast.current_derivative = 1; + + // prepend "_derivative1" to generated .hh and .cc file + opts.header_file = basename(orig_header_file) + "_derivative" + + std::to_string(driver.ast.current_derivative) + ".hh"; + opts.out_file = basename(orig_out_file) + "_derivative" + + std::to_string(driver.ast.current_derivative) + ".cc"; + + // start generating code for first derivative + back(bothD.first); + finish(); + // finish and close *.hh and *.cc for first derivative + delete opts.h_stream_; + opts.h_stream_ = NULL; + delete opts.out; + opts.out = NULL; + + + // start generating code for second derivative + driver.ast.current_derivative = 2; + + // prepend "_derivative2" to generated .hh and .cc file + opts.header_file = basename(orig_header_file) + "_derivative" + + std::to_string(driver.ast.current_derivative) + ".hh"; + opts.out_file = basename(orig_out_file) + "_derivative" + + std::to_string(driver.ast.current_derivative) + ".cc"; + + back(bothD.second); + + // revert opts filenames + opts.header_file = orig_header_file; + opts.out_file = orig_out_file; } else { back(); } diff --git a/src/symbol.cc b/src/symbol.cc index df0f1c547..3df18a864 100644 --- a/src/symbol.cc +++ b/src/symbol.cc @@ -1023,7 +1023,7 @@ void Symbol::NT::init_ret_stmts(Code::Mode mode, AST &ast) { tabfn->add_arg(ret); ret_stmts.push_back(tabfn); - if (ast.inject_derivatives && !this->is_partof_outside + if ((ast.current_derivative > 0) && !this->is_partof_outside && *this->name != OUTSIDE_AXIOMS) { Statement::Fn_Call *tracefn = new Statement::Fn_Call("set_traces"); tracefn->add(*table_decl); @@ -1090,8 +1090,9 @@ void Symbol::NT::init_table_decl(const AST &ast) { Tablegen tg; tg.set_window_mode(ast.window_mode); - table_decl = tg.create(*this, t, ast.code_mode() == Code::Mode::CYK, - ast.inject_derivatives && !this->is_partof_outside); + table_decl = tg.create( + *this, t, ast.code_mode() == Code::Mode::CYK, + this->is_partof_outside ? 0 : ast.current_derivative); } #include @@ -1268,7 +1269,7 @@ void Symbol::NT::codegen(AST &ast) { } stmts.push_back(ret_decl); - if (ast.inject_derivatives) { + if (ast.current_derivative > 0) { if (!this->is_partof_outside) { stmts.push_back(new Statement::Var_Decl(new ::Type::External( new std::string("NTtraces")), "candidates")); diff --git a/src/tablegen.cc b/src/tablegen.cc index a1bcb2dda..60c5b77db 100644 --- a/src/tablegen.cc +++ b/src/tablegen.cc @@ -348,7 +348,7 @@ void Tablegen::offset(size_t track_pos, itr f, const itr &e) { #include "symbol.hh" Statement::Table_Decl *Tablegen::create(Symbol::NT &nt, - std::string *name, bool cyk, bool for_derivatives) { + std::string *name, bool cyk, int forDerivative) { cyk_ = cyk; std::list ors; nt.gen_ys_guards(ors); @@ -369,7 +369,7 @@ Statement::Table_Decl *Tablegen::create(Symbol::NT &nt, offset(nt.track_pos(), nt.tables().begin(), nt.tables().end()); Fn_Def *fn_tab = gen_tab(); - Fn_Def *fn_set_traces = gen_set_traces(); + Fn_Def *fn_set_traces = gen_set_traces(forDerivative); ret_zero = new Statement::Return(new Expr::Vacc(new std::string("zero"))); offset(nt.track_pos(), nt.tables().begin(), nt.tables().end()); @@ -381,7 +381,7 @@ Statement::Table_Decl *Tablegen::create(Symbol::NT &nt, Statement::Table_Decl *td = new Statement::Table_Decl(nt, dtype, name, cyk, fn_is_tab, fn_tab, fn_set_traces, fn_get_traces, fn_get_tab, fn_size, ns); - td->for_derivatives = for_derivatives; + td->for_derivatives = forDerivative > 0; td->set_fn_untab(fn_untab); return td; } @@ -485,7 +485,7 @@ Fn_Def *Tablegen::gen_tab() { return f; } -Fn_Def *Tablegen::gen_set_traces() { +Fn_Def *Tablegen::gen_set_traces(int forDerivative) { Fn_Def *f = new Fn_Def(new Type::RealVoid(), new std::string("set_traces")); f->add_paras(paras); f->add_para(new ::Type::External(new std::string("NTtraces")), @@ -513,14 +513,19 @@ Fn_Def *Tablegen::gen_set_traces() { a->add_arg(new Expr::Less(off, new Expr::Fn_Call(new std::string("size")))); c.push_back(a); - Expr::Fn_Call *rhs_norm = new Expr::Fn_Call(new std::string( - "normalize_traces")); + std::string *fn_norm_name = new std::string("normalize_traces"); + if (forDerivative == 2) { + fn_norm_name = new std::string("soft_max_hessian_product"); + } + Expr::Fn_Call *rhs_norm = new Expr::Fn_Call(fn_norm_name); rhs_norm->add_arg(new Var_Acc::Array(new Var_Acc::Plain( new std::string("&traces")), off)); rhs_norm->add_arg(new std::string("candidates")); rhs_norm->add_arg(new std::string("e")); - rhs_norm->add_arg(new std::string("&" + *(new std::string( - FN_NAME_DERIVATIVE_NORMALIZER)))); + if (forDerivative == 1) { + rhs_norm->add_arg(new std::string("&" + *(new std::string( + FN_NAME_DERIVATIVE_NORMALIZER)))); + } Statement::Var_Assign *fn_norm = new Statement::Var_Assign( new Var_Acc::Array(new Var_Acc::Plain(new std::string("traces")), off), diff --git a/src/tablegen.hh b/src/tablegen.hh index 02a7689bf..90a8d429a 100644 --- a/src/tablegen.hh +++ b/src/tablegen.hh @@ -84,7 +84,7 @@ class Tablegen { Fn_Def *gen_is_tab(); Fn_Def *gen_untab(); Fn_Def *gen_tab(); - Fn_Def *gen_set_traces(); + Fn_Def *gen_set_traces(int forDerivative = 1); Fn_Def *gen_get_traces(); Fn_Def *gen_get_tab(); Fn_Def *gen_size(); @@ -97,7 +97,7 @@ class Tablegen { void offset(size_t track_pos, itr first, const itr &end); Statement::Table_Decl *create(Symbol::NT &nt, - std::string *name, bool cyk, bool for_derivatives); + std::string *name, bool cyk, int forDerivative); }; diff --git a/testdata/grammar_outside/alignments.gap b/testdata/grammar_outside/alignments.gap index 45faadc71..df153654e 100644 --- a/testdata/grammar_outside/alignments.gap +++ b/testdata/grammar_outside/alignments.gap @@ -37,7 +37,7 @@ algebra alg_similarity implements sig_alignments(alphabet=char, answer=int) { int Sto() { return 0; } - + int Region(, int x, ) { return x; } @@ -47,7 +47,7 @@ algebra alg_similarity implements sig_alignments(alphabet=char, answer=int) { int Region_Pr_Pr(, int x, ) { return x; } - + // this is slightly different form http://rna.informatik.uni-freiburg.de/Teaching/index.jsp?toolName=Gotoh# // as there Ins + Insx is computed for first blank, we here score Ins for first blank and Insx for all following ones int Insx(, , int x) { @@ -62,6 +62,49 @@ algebra alg_similarity implements sig_alignments(alphabet=char, answer=int) { } } +algebra alg_hessian implements sig_alignments(alphabet=char, answer=float) { + float Ins(, , float x) { + return x + -2.0; + } + float Del(, , float x) { + return x + -2.0; + } + float Ers(, , float x) { + if (a == b) { + return x +2.0; + } else { + return x +1.0; + } + } + float Sto() { + return 0.0; + } + + float Region(, float x, ) { + return x; + } + float Region_Pr(, float x, ) { + return x; + } + float Region_Pr_Pr(, float x, ) { + return x; + } + + // this is slightly different form http://rna.informatik.uni-freiburg.de/Teaching/index.jsp?toolName=Gotoh# + // as there Ins + Insx is computed for first blank, we here score Ins for first blank and Insx for all following ones + float Insx(, , float x) { + return x + -1.0; + } + float Delx(, , float x) { + return x + -1.0; + } + + choice [float] h([float] candidates) { + return list(sum(candidates)); + } +} + + algebra alg_score implements sig_alignments(alphabet=char, answer=float) { float normalize_derivative(float q, float pfunc) { return q / pfunc; @@ -107,6 +150,12 @@ algebra alg_score implements sig_alignments(alphabet=char, answer=float) { } } +algebra alg_jacobian extends alg_score { + choice [float] h([float] candidates) { + return list(sum(candidates)); + } +} + algebra alg_countmanual implements sig_alignments(alphabet=char, answer=int) { int Ins(, , int x) { @@ -171,6 +220,11 @@ grammar gra_needlemanwunsch uses sig_alignments(axiom=A) { } instance count = gra_needlemanwunsch(alg_count); +instance count_gotoh = gra_gotoh(alg_count); +instance enum_gotoh = gra_gotoh(alg_enum); instance sim_enum = gra_needlemanwunsch(alg_similarity * alg_enum); instance firstD = gra_needlemanwunsch(alg_score); instance firstD_gotoh = gra_gotoh(alg_score); + +instance bothD = gra_needlemanwunsch(alg_jacobian * alg_hessian); +instance bothD_gotoh = gra_gotoh(alg_jacobian * alg_hessian); \ No newline at end of file diff --git a/testdata/grammar_outside/elmamun_derivatives.gap b/testdata/grammar_outside/elmamun_derivatives.gap index f36db9346..ab1912760 100644 --- a/testdata/grammar_outside/elmamun_derivatives.gap +++ b/testdata/grammar_outside/elmamun_derivatives.gap @@ -66,7 +66,6 @@ algebra alg_hessians implements sig_elmamun(alphabet=char, answer=float) { } } - grammar gra_elmamun uses sig_elmamun(axiom = formula) { formula = number(INT) | add(formula, CHAR('+'), formula) diff --git a/testdata/grammar_outside/hmm_sonneregen_properEnd.gap b/testdata/grammar_outside/hmm_sonneregen_properEnd.gap index b4886a954..896fa7548 100644 --- a/testdata/grammar_outside/hmm_sonneregen_properEnd.gap +++ b/testdata/grammar_outside/hmm_sonneregen_properEnd.gap @@ -186,7 +186,6 @@ algebra alg_hessians implements sig_weather(alphabet=char, answer=float) { } - algebra alg_fwd_log implements sig_weather(alphabet=char, answer=float) { float transition_start_hoch(float transition, float emission, float x) { return log(transition) + emission + x; @@ -523,7 +522,6 @@ instance fwd_log = gra_weather(alg_fwd_log); instance fwd_neglog = gra_weather(alg_fwd_neglog); instance multviterbistates = gra_weather(alg_mult * alg_viterbi * alg_states); - instance bothD = gra_weather(alg_fwd * alg_hessians); instance bothD_log = gra_weather(alg_fwd_log * alg_hessians); instance bothD_neglog = gra_weather(alg_fwd_neglog * alg_hessians); diff --git a/testdata/grammar_outside/nodangle.gap b/testdata/grammar_outside/nodangle.gap index c706cb0dc..a53026359 100644 --- a/testdata/grammar_outside/nodangle.gap +++ b/testdata/grammar_outside/nodangle.gap @@ -108,6 +108,49 @@ algebra alg_pfunc implements sig_foldrna(alphabet = char, answer = double) { } } +// similar to alg_mfe, but datatype changed from int to double and h is sum +algebra alg_hessians implements sig_foldrna(alphabet = char, answer = double) { + double sadd(Subsequence lb, double x) { + return x + sbase_energy(); + } + double cadd(double x, double y) { + return x + y; + } + double drem(Subsequence lb, double x, Subsequence rb) { + return x + termau_energy(lb, rb); + } + double sr(Subsequence lb, double x, Subsequence rb) { + return x + sr_energy(lb, rb); + } + double hl(Subsequence lb, Subsequence r, Subsequence rb) { + return hl_energy(r); + } + double bl(Subsequence lb, Subsequence lr, double x, Subsequence rb) { + return x + bl_energy(lr, rb); + } + double br(Subsequence lb, double x, Subsequence rr, Subsequence rb) { + return x + br_energy(lb, rr); + } + double il(Subsequence lb, Subsequence lr, double x, Subsequence rr, Subsequence rb) { + return x + il_energy(lr, rr); + } + double ml(Subsequence lb, double x, Subsequence rb) { + return x + ml_energy() + ul_energy() + termau_energy(lb, rb); + } + double incl(double x) { + return x + ul_energy(); + } + double addss(double x, Subsequence r) { + return x + ss_energy(r); + } + double nil(Subsequence n) { + return 0.0; + } + choice [double] h([double] i) { + return list(sum(i)); + } +} + algebra alg_dotBracket implements sig_foldrna(alphabet = char, answer = string) { string sadd(Subsequence lb,string e) { string res; @@ -236,3 +279,4 @@ grammar gra_nodangle uses sig_foldrna(axiom = struct) { instance mfe = gra_nodangle(alg_mfe); instance pfunc = gra_nodangle(alg_pfunc); instance count = gra_nodangle(alg_count); +instance bothD = gra_nodangle(alg_pfunc * alg_hessians); diff --git a/testdata/regresstest/config b/testdata/regresstest/config index 1c3040c8a..b298d3026 100644 --- a/testdata/regresstest/config +++ b/testdata/regresstest/config @@ -498,8 +498,22 @@ check_new_old_eq nodangle.gap unused pfunc "GCaaaGC" nodanglederiv check_new_old_eq nodangle.gap unused pfunc "CCaCCaaaGGaCCaaaGGaCCaaaGGaGG" nodanglederivlong # results of the below HMM where extensively validated by Stefan via Excel and python implementations -# tests for special "normalize_derivative" algebra function CPPFLAGS_EXTRA="$DEFAULT_CPPFLAGS_EXTRA" # for negexpsum function, defined as external *.hh file +# tests for special "normalize_derivative" algebra function check_new_old_eq hmm_sonneregen_properEnd.gap unused fwd "SSRR" sonneregen1deriv # probs - sum check_new_old_eq hmm_sonneregen_properEnd.gap unused fwd_log "SSRR" sonneregen1deriv_log # log - expsum check_new_old_eq hmm_sonneregen_properEnd.gap unused fwd_neglog "SSRR" sonneregen1deriv_neglog # neglog - negexpsum + +# tests for second derivative computation +GRAMMAR=../../grammar_outside +GAPC="../../../gapc --derivative 2" +check_new_old_eq_twotrack alignments.gap unused bothD "aaaa" secondderiv "bbbb" # validated against jupyter python version +check_new_old_eq_twotrack alignments.gap unused bothD "frzeitei" nwjamie2deriv "zeit" # validated against jupyter python version +check_new_old_eq_twotrack alignments.gap unused bothD_gotoh "frzeitei" gotoh2deriv "zeit" +check_new_old_eq elmamun_derivatives.gap unused bothD "1+2*3*4+5" secondderiv +check_new_old_eq nodangle.gap unused bothD "CCaCCaaaGGaCCaaaGGaCCaaaGGaGG" secondderiv + +# despite different score normalizations for jacobians=1D, hessians=2D are the same +check_new_old_eq hmm_sonneregen_properEnd.gap unused bothD "SSRR" sonneregen2deriv_normal # probs - sum +check_new_old_eq hmm_sonneregen_properEnd.gap unused bothD_log "SSRR" sonneregen2deriv_log # probs - sum +check_new_old_eq hmm_sonneregen_properEnd.gap unused bothD_neglog "SSRR" sonneregen2deriv_neglog # probs - sum