diff --git a/src/spline.h b/src/spline.h index 28e3dea..29434ef 100644 --- a/src/spline.h +++ b/src/spline.h @@ -27,6 +27,7 @@ #ifndef TK_SPLINE_H #define TK_SPLINE_H +#include #include #include #include @@ -120,6 +121,46 @@ class spline // --------------------------------------------------------------------- +// sort multiple vector based on the order of the first +// ------------------------- +namespace multiSort { + template + std::vector sort_permutation( std::vector const& vec, Compare compare ) { + std::vector p(vec.size()); + std::iota(p.begin(), p.end(), 0); + std::sort(p.begin(), p.end(), + [&](int i, int j){ return compare(vec[i], vec[j]); }); + return p; + } + template + std::vector sort_permutation( std::vector const& vec ) { + return sort_permutation( vec, std::less() ); + } + + template + std::vector apply_permutation( std::vector const& vec, std::vector const& p) { + std::vector sorted_vec(p.size()); + std::transform(p.begin(), p.end(), sorted_vec.begin(), + [&](int i){ return vec[i]; }); + return sorted_vec; + } + template + void apply_permutation( std::vector& vec, std::vector const& p) { + std::vector sorted_vec(p.size()); + std::transform(p.begin(), p.end(), sorted_vec.begin(), + [&](int i){ return vec[i]; }); + vec = std::move( sorted_vec ); + } + + template + void sort( std::vector& v1, std::vector& v2 ) { + std::vector perm = sort_permutation( v1 ); + apply_permutation( v1, perm ); + apply_permutation( v2, perm ); + } +}; + + // band_matrix implementation // ------------------------- @@ -281,18 +322,29 @@ void spline::set_boundary(spline::bd_type left, double left_value, } -void spline::set_points(const std::vector& x, - const std::vector& y, bool cubic_spline) +void spline::set_points(const std::vector& xin, + const std::vector& yin, bool cubic_spline) { - assert(x.size()==y.size()); - assert(x.size()>2); - m_x=x; - m_y=y; - int n=x.size(); - // TODO: maybe sort x and y, rather than returning an error + assert(xin.size()==yin.size()); + assert(xin.size()>2); + m_x=xin; + m_y=yin; + int n=xin.size(); + // sort x and y simultaneously for(int i=0; im_x[i+1]) { + multiSort::sort(m_x, m_y); + break; + } } + // check m_x to be strictly increasing + for(int i=0; i& x = m_x; + const std::vector& y = m_y; if(cubic_spline==true) { // cubic spline interpolation // setting up the matrix and right hand side of the equation system