|
a |
|
b/inmoose/utils/splines.py |
|
|
1 |
# ----------------------------------------------------------------------------- |
|
|
2 |
# Copyright (C) 2024 Maximilien Colange |
|
|
3 |
|
|
|
4 |
# This program is free software: you can redistribute it and/or modify |
|
|
5 |
# it under the terms of the GNU General Public License as published by |
|
|
6 |
# the Free Software Foundation, either version 3 of the License, or |
|
|
7 |
# (at your option) any later version. |
|
|
8 |
|
|
|
9 |
# This program is distributed in the hope that it will be useful, |
|
|
10 |
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
|
11 |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
|
12 |
# GNU General Public License for more details. |
|
|
13 |
|
|
|
14 |
# You should have received a copy of the GNU General Public License |
|
|
15 |
# along with this program. If not, see <https://www.gnu.org/licenses/>. |
|
|
16 |
# ----------------------------------------------------------------------------- |
|
|
17 |
|
|
|
18 |
# original code was contributed to StackOverflow: https://stackoverflow.com/questions/71550468/does-python-have-an-analogue-to-rs-splinesns |
|
|
19 |
# code was improved to better match the R code |
|
|
20 |
|
|
|
21 |
# As per StackOverflow terms of service (see https://stackoverflow.com/help/licensing), original code is licensed under CC BY SA 4.0, which is compatible with GPL3 (see https://creativecommons.org/2015/10/08/cc-by-sa-4-0-now-one-way-compatible-with-gplv3/) |
|
|
22 |
|
|
|
23 |
import numpy as np |
|
|
24 |
from scipy.interpolate import splev |
|
|
25 |
|
|
|
26 |
from ..utils import LOGGER |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def spline_design(knots, x, order, derivs=0):
    """
    Evaluate the design matrix for the B-splines defined by :code:`knots` at the values in :code:`x`.

    Arguments
    ---------
    knots : array_like
        vector of knot positions (sorted increasingly before use).
    x : array_like
        vector of values at which to evaluate the B-spline functions or
        derivatives. The values in x are expected to lie between the "inner"
        knots :code:`knots[order - 1]` and :code:`knots[len(knots) - order]`.
    order : int
        a positive integer giving the order of the spline function. This is the
        number of coefficients in each piecewise polynomial segment, thus a
        cubic spline has order 4.
    derivs : int or array_like of int
        derivative order(s), with values between 0 and :code:`order` - 1,
        recycled to the length of :code:`x`. The derivative of the given order
        is evaluated at the corresponding :code:`x` position. Defaults to zero.

    Returns
    -------
    ndarray
        a matrix with :code:`len(x)` rows and :code:`len(knots) - order`
        columns. The :code:`i`'th row of the matrix contains the values of the
        B-splines (or the indicated derivative of the B-splines) defined by the
        knot vector and evaluated at the :code:`i`'th value of :code:`x`.
    """
    knots = np.sort(np.asarray(knots, dtype=float))
    x = np.asarray(x, dtype=float)
    n_points = len(x)

    # Recycle the derivative orders to one entry per evaluation point
    # (a scalar is repeated, a shorter vector is cycled, a longer one is cut).
    der = np.resize(np.asarray(derivs, dtype=int), n_points)
    distinct_orders = np.unique(der)

    n_bases = len(knots) - order
    res = np.empty((n_points, n_bases), dtype=float)
    for i in range(n_bases):
        # Unit coefficient vector selects the i-th B-spline basis function.
        coefs = np.zeros((n_bases,))
        coefs[i] = 1
        tck = (knots, coefs, order - 1)
        # One splev call per distinct derivative order instead of one per row.
        for d in distinct_orders:
            rows = der == d
            res[rows, i] = splev(x, tck, der=int(d))[rows]
    return res
|
|
74 |
|
|
|
75 |
|
|
|
76 |
class ns:
    """
    Class storing the B-spline basis matrix for a natural cubic spline and info used to generate it.

    Attributes
    ----------
    knots : array_like
        interior breakpoints that define the spline (an empty array when there
        are none).
    include_intercept : bool
        whether an intercept is included in the basis
    boundary_knots : list
        the two boundary points, in increasing order, at which the natural
        boundary conditions are imposed and the B-spline basis is anchored.
    basis : ndarray
        a matrix of dimension :code:`(len(x), df)`, where :code:`df =
        len(knots) + 1 + include_intercept`.
    """

    def __init__(
        self, x, df=None, knots=None, boundary_knots=None, include_intercept=False
    ):
        """
        Generate the B-spline basis matrix for a natural cubic spline.

        This function intends to provide the same functionality as R splines::ns.
        The resulting matrix is stored in the :code:`basis` attribute.

        Arguments
        ---------
        x : array_like
            the predictor variable
        df : int, optional
            degrees of freedom. If :code:`knots` is not specified, then the
            function chooses :code:`df - 1 - intercept` knots at suitably chosen
            quantiles of :code:`x`. Ignored when :code:`knots` is supplied.
        knots : array_like, optional
            breakpoints that define the spline. The default is no knots; together
            with the natural boundary conditions this results in a basis for linear
            regression on :code:`x`. Typical values are the mean or median for one
            knot, quantiles for more knots. See also :code:`boundary_knots`.
        boundary_knots : array_like, optional
            boundary points at which to impose the natural boundary conditions and
            anchor the B-spline basis (default the range of the data). If both
            :code:`knots` and :code:`boundary_knots` are supplied, the basis
            parameters do not depend on :code:`x`. Data can extend beyond
            :code:`boundary_knots`.
        include_intercept : bool, optional
            if true, an intercept is included in the basis; default is
            :code:`False`.
        """
        self.include_intercept = include_intercept
        x = np.asarray(x)
        if boundary_knots is None:
            # Boundaries default to the data range, so no point can be outside.
            boundary_knots = [np.min(x), np.max(x)]
            outside = False
        else:
            boundary_knots = list(np.sort(boundary_knots))
            out_left = x < boundary_knots[0]
            out_right = x > boundary_knots[1]
            outside = out_left | out_right
        self.boundary_knots = boundary_knots

        if df is not None and knots is None:
            n_requested = df - 1 - include_intercept
            if n_requested < 0:
                n_requested = 0
                LOGGER.warning(
                    f"df was too small, used {1 + bool(include_intercept)}"
                )

            if n_requested > 0:
                # Interior knots at evenly spaced quantiles of x.
                probs = np.linspace(0, 1, num=n_requested + 2)[1:-1]
                knots = np.quantile(x, probs)
        if knots is None:
            # No interior knots: neither knots nor a large enough df was given.
            knots = np.array([], dtype=float)
        nIknots = len(knots)
        self.knots = knots

        # Full knot vector: each boundary knot repeated 4 times (cubic splines
        # have order 4) plus the interior knots.
        Aknots = np.sort(np.concatenate((boundary_knots * 4, knots)))

        if np.any(outside):
            basis = np.empty((x.shape[0], nIknots + 4), dtype=float)
            if np.any(out_left):
                # Linear extrapolation left of the lower boundary knot, using
                # the value and first derivative of the basis at that knot.
                k_pivot = boundary_knots[0]
                xl = np.ones((np.sum(out_left), 2))
                xl[:, 1] = x[out_left] - k_pivot
                tt = spline_design(Aknots, [k_pivot, k_pivot], 4, [0, 1])
                basis[out_left, :] = xl @ tt
            if np.any(out_right):
                # Same linear extrapolation right of the upper boundary knot.
                k_pivot = boundary_knots[1]
                xr = np.ones((np.sum(out_right), 2))
                xr[:, 1] = x[out_right] - k_pivot
                tt = spline_design(Aknots, [k_pivot, k_pivot], 4, [0, 1])
                basis[out_right, :] = xr @ tt
            inside = ~outside
            if np.any(inside):
                basis[inside, :] = spline_design(Aknots, x[inside], 4)
        else:
            basis = spline_design(Aknots, x, 4)

        # Second derivatives of the basis at both boundary knots: the natural
        # spline constraints require these to vanish.
        const = spline_design(Aknots, boundary_knots, 4, [2, 2])

        if not include_intercept:
            # Drop the first basis function (the intercept column).
            basis = basis[:, 1:]
            const = const[:, 1:]

        # Rotate the basis with the Q factor of the constraints and drop the
        # first two columns, which span the constrained directions.
        qr_const = np.linalg.qr(const.T, mode="complete")[0]
        self.basis = (qr_const.T @ basis.T).T[:, 2:]

    def predict(self, newx):
        """
        Evaluate the spline basis at given values

        Arguments
        ---------
        newx : ndarray
            new predictor variable to regenerate the spline from

        Returns
        -------
        ns
            a new natural spline object, built with the knots, boundary knots
            and intercept setting of this object and evaluated at the given
            values
        """
        return ns(
            newx,
            knots=self.knots,
            boundary_knots=self.boundary_knots,
            include_intercept=self.include_intercept,
        )