Diff of /inmoose/utils/splines.py [000000] .. [ea0fd6]

Switch to unified view

a b/inmoose/utils/splines.py
1
# -----------------------------------------------------------------------------
2
# Copyright (C) 2024 Maximilien Colange
3
4
# This program is free software: you can redistribute it and/or modify
5
# it under the terms of the GNU General Public License as published by
6
# the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU General Public License for more details.
13
14
# You should have received a copy of the GNU General Public License
15
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
16
# -----------------------------------------------------------------------------
17
18
# original code was contributed to StackOverflow: https://stackoverflow.com/questions/71550468/does-python-have-an-analogue-to-rs-splinesns
19
# code was improved to better match the R code
20
21
# As per StackOverflow terms of service (see https://stackoverflow.com/help/licensing), original code is licensed under CC BY SA 4.0, which is compatible with GPL3 (see https://creativecommons.org/2015/10/08/cc-by-sa-4-0-now-one-way-compatible-with-gplv3/)
22
23
import numpy as np
24
from scipy.interpolate import splev
25
26
from ..utils import LOGGER
27
28
29
def spline_design(knots, x, order, derivs=0):
30
    """
31
    Evaluate the design matrix for the B-splines defined by :code:`knots` at the values in :code:`x`.
32
33
    Arguments
34
    ---------
35
    knots : array_like
36
        vector of knot positions (which will be sorted increasingly if needed).
37
    x : array_like
38
        vector of values at which to evaluate the B-spline functions or
39
        derivatives. The values in x must be between the “inner” knots
40
        :code:`knots[ord]` and :code:`knots[ length(knots) - (ord-1)]`
41
    order : int
42
        a positive integer giving the order of the spline function. This is the
43
        number of coefficients in each piecewise polynomial segment, thus a
44
        cubic spline has order 4.
45
    derivs : array_like
46
        an integer vector with values between 0 and :code:`ord` - 1,
47
        conceptually recycled to the length of :code:`x`. The derivative of the
48
        given order is evaluated at the :code:`x` positions. Defaults to zero.
49
50
    Returns
51
    -------
52
    ndarray
53
        a matrix with :code:`len(x)` rows and :code:`len(knots) - ord` columns.
54
        The :code:`i`'th row of the matrix contains the coefficients of the
55
        B-splines (or the indicated derivative of the B-splines) defined by the
56
        knot vector and evaluated at the :code:`i`'th value of :code:`x`. Each
57
        B-spline is defined by a set of :code:`ord` successive knots so the
58
        total number of B-splines is :code:`len(knots) - ord`.
59
    """
60
    derivs = np.asarray(derivs)
61
    if derivs.ndim == 0:
62
        der = np.repeat(derivs, len(x))
63
    else:
64
        der = np.zeros(len(x), dtype=int)
65
        der[: len(derivs)] = derivs
66
    n_bases = len(knots) - order
67
    res = np.empty((len(x), n_bases), dtype=float)
68
    for i in range(n_bases):
69
        coefs = np.zeros((n_bases,))
70
        coefs[i] = 1
71
        for j in range(len(x)):
72
            res[j, i] = splev(x, (knots, coefs, order - 1), der=der[j])[j]
73
    return res
74
75
76
class ns:
77
    """
78
    Class storing the B-spline basis matrix for a natural cubic spline and info used to generate it.
79
80
    Attributes
81
    ----------
82
    knots : array_like
83
        breakpoints that define the spline.
84
    include_intercept : bool
85
        whether an intercept is included in the basis
86
    boundary_knots : array_like
87
        boundary points at which to impose the natural boundary conditions and
88
        anchor the B-spline basis.
89
    basis : ndarray
90
        a matrix of dimension :code:`(len(x), df)`, where :code:`df =
91
        len(knots)-1-intercept` if :code:`df` was not supplied.
92
    """
93
94
    def __init__(
95
        self, x, df=None, knots=None, boundary_knots=None, include_intercept=False
96
    ):
97
        """
98
        Generate the B-spline basis matrix for a natural cubic spline.
99
100
        This function intends to provide the same functionality as R splines::ns.
101
102
        Arguments
103
        ---------
104
        x : array_like
105
            the predictor variable
106
        df : int, optional
107
            degrees of freedom. If :code:`knots` is not specified, then the
108
            function chooses :code:`df - 1 - intercept` knots at suitably chosen
109
            quantiles of :code:`x`. If :code:`None`, the number of inner knots is
110
            set to :code:`len(knots)`.
111
        knots : array_like, optional
112
            breakpoints that define the spline. The default is no knots; together
113
            with the natural boundary conditions this results in a basis for linear
114
            regression on :code:`x`. Typical values are the mean or median for one
115
            knot, quantiles for more knots. See also :code:`boundary_knots`.
116
        include_intercept : bool, optional
117
            if :code:`True`, an intercept is included in the basis; default is
118
            :code:`False`.
119
        boundary_knots : array_like, optional
120
            boundary points at which to impose the natural boundary conditions and
121
            anchor the B-spline basis (default the range of the data). If both
122
            :code:`knots` and :code:`boundary_knots` are supplied, the basis
123
            parameters do not depend on :code:`x`. Data can extend beyond
124
            :code:`boundary_knots`.
125
126
        Returns
127
        -------
128
        ndarray
129
            a matrix of dimension :code:`(len(x), df)`, where :code:`df =
130
            len(knots)-1-intercept` if :code:`df` was not supplied.
131
        """
132
        self.include_intercept = include_intercept
133
        x = np.asarray(x)
134
        if boundary_knots is None:
135
            boundary_knots = [np.min(x), np.max(x)]
136
            outside = False
137
        else:
138
            boundary_knots = list(np.sort(boundary_knots))
139
            out_left = x < boundary_knots[0]
140
            out_right = x > boundary_knots[1]
141
            outside = out_left | out_right
142
        self.boundary_knots = boundary_knots
143
144
        if df is not None and knots is None:
145
            nIknots = df - 1 - include_intercept
146
            if nIknots < 0:
147
                nIknots = 0
148
                LOGGER.warning("df was too small, used {1+include_intercept}")
149
150
            if nIknots > 0:
151
                knots = np.linspace(0, 1, num=nIknots + 2)[1:-1]
152
                knots = np.quantile(x, knots)
153
        else:
154
            nIknots = len(knots)
155
        self.knots = knots
156
157
        Aknots = np.sort(np.concatenate((boundary_knots * 4, knots)))
158
159
        if np.any(outside):
160
            basis = np.empty((x.shape[0], nIknots + 4), dtype=float)
161
            if np.any(out_left):
162
                k_pivot = boundary_knots[0]
163
                xl = np.ones((np.sum(out_left), 2))
164
                xl[:, 1] = x[out_left] - k_pivot
165
                tt = spline_design(Aknots, [k_pivot, k_pivot], 4, [0, 1])
166
                basis[out_left, :] = xl @ tt
167
            if np.any(out_right):
168
                k_pivot = boundary_knots[1]
169
                xr = np.ones((np.sum(out_right), 2))
170
                xr[:, 1] = x[out_right] - k_pivot
171
                tt = spline_design(Aknots, [k_pivot, k_pivot], 4, [0, 1])
172
                basis[out_right, :] = xr @ tt
173
            inside = ~outside
174
            if np.any(inside):
175
                basis[inside, :] = spline_design(Aknots, x[inside], 4)
176
        else:
177
            basis = spline_design(Aknots, x, 4)
178
179
        const = spline_design(Aknots, boundary_knots, 4, [2, 2])
180
181
        if include_intercept is False:
182
            basis = basis[:, 1:]
183
            const = const[:, 1:]
184
185
        qr_const = np.linalg.qr(const.T, mode="complete")[0]
186
        self.basis = (qr_const.T @ basis.T).T[:, 2:]
187
188
    def predict(self, newx):
189
        """
190
        Evaluate the spline basis at given values
191
192
        Arguments
193
        ---------
194
        newx : ndarray
195
            new predictor variable to regenerate the spline from
196
197
        Returns
198
        -------
199
        ns
200
            a new natural spline object, evaluated at the given values
201
        """
202
        return ns(
203
            newx,
204
            knots=self.knots,
205
            boundary_knots=self.boundary_knots,
206
            include_intercept=self.include_intercept,
207
        )