diff --git a/src/yapsut/numba_cubic_spline.py b/src/yapsut/numba_cubic_spline.py index 172a7d3a6cdae434b48671bc7b4260621e582474..92ac4cfb609829de36d9f8f06a5d651d1316a46f 100644 --- a/src/yapsut/numba_cubic_spline.py +++ b/src/yapsut/numba_cubic_spline.py @@ -47,6 +47,29 @@ def evaluate_spline(x, y, b, c, d, x_eval): break return y_eval +@njit +def evaluate_spline_singlex(x, y, b, c, d, x_eval): + if xi <= x_eval[0] : + return y[0] + if x[-1] <= x_eval : + return y[-1] + # + i=len(x)//2 + while True : + if x_eval < x[i] : + i=i//2 + if i==0 : + return y[0] + elif x[i+1]<x_eval : + i=(i+len(x)-1)//2 + if i==len(x)-1 : + return y[-1] + else : + #if x[i] <= x_eval <= x[i+1]: + dx = x_eval - x[i] + return y[i] + b[i]*dx + c[i]*dx**2 + d[i]*dx**3 + return + @njit def evaluate_spline_derivative(x, b, c, d, x_eval): n = len(x) - 1 @@ -124,6 +147,9 @@ class NumbaNaturalCubicSpline: self.x = np.asarray(x) self.y = np.asarray(y) self.b, self.c, self.d = _compute_coeffs(self.x, self.y) + + def get_vectors(self) : + return [self.x,self.y,self.b,self.c,self.d] def __call__(self, x_eval): x_eval = np.asarray(x_eval)