Skip to content
Draft
9 changes: 6 additions & 3 deletions lib/stormpy/pomdp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,18 @@ def create_nondeterminstic_belief_tracker(model, reduction_timeout, track_timeou
return pomdp.NondeterministicBeliefTrackerDoubleSparse(model, opts)


def create_observation_trace_unfolder(model, risk_assessment, expr_manager):
def create_observation_trace_unfolder(model, risk_assessment, expr_manager, rejection_sampling = True):
"""

:param model:
:param risk_assessment:
:param expr_manager:
:param rejection_sampling:
:return:
"""
options = pomdp.ObservationTraceUnfolderOptions()
options.rejection_sampling = rejection_sampling
if model.is_exact:
return pomdp.ObservationTraceUnfolderExact(model, risk_assessment, expr_manager)
return pomdp.ObservationTraceUnfolderExact(model, risk_assessment, expr_manager, options)
else:
return pomdp.ObservationTraceUnfolderDouble(model, risk_assessment, expr_manager)
return pomdp.ObservationTraceUnfolderDouble(model, risk_assessment, expr_manager, options)
19 changes: 13 additions & 6 deletions src/pomdp/transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ std::shared_ptr<storm::models::sparse::Model<storm::RationalFunction>> apply_unk
}

template<typename ValueType>
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> unfold_trace(storm::models::sparse::Pomdp<ValueType> const& pomdp, std::shared_ptr<storm::expressions::ExpressionManager>& exprManager, std::vector<uint32_t> const& observationTrace, std::vector<ValueType> const& riskDef ) {
storm::pomdp::ObservationTraceUnfolder<ValueType> transformer(pomdp, exprManager);
return transformer.transform(observationTrace, riskDef);
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> unfold_trace(storm::models::sparse::Pomdp<ValueType> const& pomdp, std::shared_ptr<storm::expressions::ExpressionManager>& exprManager, std::vector<uint32_t> const& observationTrace, std::vector<ValueType> const& riskDef, bool rejectionSampling=true) {
storm::pomdp::ObservationTraceUnfolderOptions options = storm::pomdp::ObservationTraceUnfolderOptions();
options.rejectionSampling = rejectionSampling;
storm::pomdp::ObservationTraceUnfolder<ValueType> transformer(pomdp, riskDef, exprManager, options);
return transformer.transform(observationTrace);
}

// STANDARD, SIMPLE_LINEAR, SIMPLE_LINEAR_INVERSE, SIMPLE_LOG, FULL
Expand All @@ -47,6 +49,11 @@ void define_transformations_nt(py::module &m) {
.value("full", storm::transformer::PomdpFscApplicationMode::FULL)
;

py::class_<storm::pomdp::ObservationTraceUnfolderOptions> unfolderOptions(m, "ObservationTraceUnfolderOptions", "Options for the ObservationTraceUnfolder");
unfolderOptions.def(py::init<>());
unfolderOptions.def_readwrite("rejection_sampling", &storm::pomdp::ObservationTraceUnfolderOptions::rejectionSampling);


}

template<typename ValueType>
Expand All @@ -55,12 +62,12 @@ void define_transformations(py::module& m, std::string const& vtSuffix) {
m.def(("_unfold_memory_" + vtSuffix).c_str(), &unfold_memory<ValueType>, "Unfold memory into a POMDP", py::arg("pomdp"), py::arg("memorystructure"), py::arg("memorylabels") = false, py::arg("keep_state_valuations")=false);
m.def(("_make_simple_"+ vtSuffix).c_str(), &make_simple<ValueType>, "Make POMDP simple", py::arg("pomdp"), py::arg("keep_state_valuations")=false);
m.def(("_apply_unknown_fsc_" + vtSuffix).c_str(), &apply_unknown_fsc<ValueType>, "Apply unknown FSC",py::arg("pomdp"), py::arg("application_mode")=storm::transformer::PomdpFscApplicationMode::SIMPLE_LINEAR);
//m.def(("_unfold_trace_" + vtSuffix).c_str(), &unfold_trace<ValueType>, "Unfold observed trace", py::arg("pomdp"), py::arg("expression_manager"),py::arg("observation_trace"), py::arg("risk_definition"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the unfold_trace method is not binded anymore, it could be removed.


py::class_<storm::pomdp::ObservationTraceUnfolder<ValueType>> unfolder(m, ("ObservationTraceUnfolder" + vtSuffix).c_str(), "Unfolds observation traces in models");
unfolder.def(py::init<storm::models::sparse::Pomdp<ValueType> const&, std::vector<ValueType> const&, std::shared_ptr<storm::expressions::ExpressionManager>&>(), py::arg("model"), py::arg("risk"), py::arg("expression_manager"));
unfolder.def(py::init<storm::models::sparse::Pomdp<ValueType> const&, std::vector<ValueType> const&, std::shared_ptr<storm::expressions::ExpressionManager>&, storm::pomdp::ObservationTraceUnfolderOptions const&>(), py::arg("model"), py::arg("risk"), py::arg("expression_manager"), py::arg("options"));
unfolder.def("is_rejection_sampling_set", &storm::pomdp::ObservationTraceUnfolder<ValueType>::isRejectionSamplingSet);
unfolder.def("transform", &storm::pomdp::ObservationTraceUnfolder<ValueType>::transform, py::arg("trace"));
unfolder.def("extend", &storm::pomdp::ObservationTraceUnfolder<ValueType>::extend, py::arg("new_observation"));
unfolder.def("extend", &storm::pomdp::ObservationTraceUnfolder<ValueType>::extend, py::arg("new_observations"));
unfolder.def("reset", &storm::pomdp::ObservationTraceUnfolder<ValueType>::reset, py::arg("new_observation"));
}

Expand Down