Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions src/spline.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#ifndef TK_SPLINE_H
#define TK_SPLINE_H

#include <iostream>
#include <cstdio>
#include <cassert>
#include <vector>
Expand Down Expand Up @@ -120,6 +121,46 @@ class spline
// ---------------------------------------------------------------------


// sort multiple vector based on the order of the first
// -------------------------
namespace multiSort {
template <typename T, typename Compare>
std::vector<int> sort_permutation( std::vector<T> const& vec, Compare compare ) {
std::vector<int> p(vec.size());
std::iota(p.begin(), p.end(), 0);
std::sort(p.begin(), p.end(),
[&](int i, int j){ return compare(vec[i], vec[j]); });
return p;
}
template <typename T>
std::vector<int> sort_permutation( std::vector<T> const& vec ) {
return sort_permutation( vec, std::less<T>() );
}

template <typename T>
std::vector<T> apply_permutation( std::vector<T> const& vec, std::vector<int> const& p) {
std::vector<T> sorted_vec(p.size());
std::transform(p.begin(), p.end(), sorted_vec.begin(),
[&](int i){ return vec[i]; });
return sorted_vec;
}
template <typename T>
void apply_permutation( std::vector<T>& vec, std::vector<int> const& p) {
std::vector<T> sorted_vec(p.size());
std::transform(p.begin(), p.end(), sorted_vec.begin(),
[&](int i){ return vec[i]; });
vec = std::move( sorted_vec );
}

template <typename T>
void sort( std::vector<T>& v1, std::vector<T>& v2 ) {
std::vector<int> perm = sort_permutation( v1 );
apply_permutation( v1, perm );
apply_permutation( v2, perm );
}
};


// band_matrix implementation
// -------------------------

Expand Down Expand Up @@ -281,18 +322,29 @@ void spline::set_boundary(spline::bd_type left, double left_value,
}


void spline::set_points(const std::vector<double>& x,
const std::vector<double>& y, bool cubic_spline)
void spline::set_points(const std::vector<double>& xin,
const std::vector<double>& yin, bool cubic_spline)
{
assert(x.size()==y.size());
assert(x.size()>2);
m_x=x;
m_y=y;
int n=x.size();
// TODO: maybe sort x and y, rather than returning an error
assert(xin.size()==yin.size());
assert(xin.size()>2);
m_x=xin;
m_y=yin;
int n=xin.size();
// sort x and y simultaneously
for(int i=0; i<n-1; i++) {
assert(m_x[i]<m_x[i+1]);
//assert(m_x[i]<m_x[i+1]);
if(m_x[i]>m_x[i+1]) {
multiSort::sort(m_x, m_y);
break;
}
}
// check m_x to be strictly increasing
for(int i=0; i<n-1; i++)
assert(m_x[i]<m_x[i+1]);

// create reference to x, y data
const std::vector<double>& x = m_x;
const std::vector<double>& y = m_y;

if(cubic_spline==true) { // cubic spline interpolation
// setting up the matrix and right hand side of the equation system
Expand Down