[6ff4a8]: / src / moscot / base / problems / birth_death.py

Download this file

296 lines (245 with data), 11.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Sequence, Union
import numpy as np
import scanpy as sc
from anndata import AnnData
from moscot._logging import logger
from moscot._types import ArrayLike
from moscot.base.problems.problem import AbstractAdataAccess, OTProblem
from moscot.utils.data import apoptosis_markers, proliferation_markers
# Public API of this module; also controls what ``import *`` exposes.
__all__ = ["BirthDeathProblem", "BirthDeathMixin"]
class BirthDeathMixin(AbstractAdataAccess):
    """Mixin that scores marker genes to obtain prior cell proliferation and apoptosis estimates.

    Parameters
    ----------
    args
        Positional arguments.
    kwargs
        Keyword arguments.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Columns of ``adata.obs`` holding the scores; ``None`` means "not available".
        self._proliferation_key: Optional[str] = None
        self._apoptosis_key: Optional[str] = None
        # Updated by ``BirthDeathProblem.estimate_marginals`` when source marginals are estimated.
        self._scaling: float = 1.0
        self._prior_growth: Optional[ArrayLike] = None

    def score_genes_for_marginals(
        self,
        gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
        gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
        proliferation_key: str = "proliferation",
        apoptosis_key: str = "apoptosis",
        **kwargs: Any,
    ) -> "BirthDeathMixin":
        """Score marker genes to obtain prior knowledge about proliferation and apoptosis.

        The resulting scores can be passed to :meth:`~moscot.base.problems.BirthDeathProblem.estimate_marginals`
        to estimate the initial growth rates, as suggested in :cite:`schiebinger:19`.

        Parameters
        ----------
        gene_set_proliferation
            Proliferation marker genes. A :class:`str` is interpreted as an organism name and
            resolved via :func:`~moscot.utils.data.proliferation_markers`. If :obj:`None`,
            proliferation is not scored and :attr:`proliferation_key` is reset to :obj:`None`.
        gene_set_apoptosis
            Apoptosis marker genes. A :class:`str` is interpreted as an organism name and
            resolved via :func:`~moscot.utils.data.apoptosis_markers`. If :obj:`None`,
            apoptosis is not scored and :attr:`apoptosis_key` is reset to :obj:`None`.
        proliferation_key
            Column of :attr:`~anndata.AnnData.obs` in which to store the proliferation scores.
        apoptosis_key
            Column of :attr:`~anndata.AnnData.obs` in which to store the apoptosis scores.
        kwargs
            Keyword arguments for :func:`~scanpy.tl.score_genes`, shared by both scoring calls.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`proliferation_key` - key in :attr:`~anndata.AnnData.obs` where proliferation scores are stored.
        - :attr:`apoptosis_key` - key in :attr:`~anndata.AnnData.obs` where apoptosis scores are stored.

        A warning is logged when neither gene set is given.
        """

        def _score(gene_set: Any, markers: Any, key: str) -> Optional[str]:
            # An organism name is first expanded into its curated marker list.
            if isinstance(gene_set, str):
                gene_set = markers(gene_set)
            if gene_set is None:
                return None
            sc.tl.score_genes(self.adata, gene_set, score_name=key, **kwargs)
            return key

        # Proliferation is scored (and its key assigned) before apoptosis, as the setters
        # validate that the freshly written column exists in ``adata.obs``.
        self.proliferation_key = _score(gene_set_proliferation, proliferation_markers, proliferation_key)
        self.apoptosis_key = _score(gene_set_apoptosis, apoptosis_markers, apoptosis_key)

        if self.proliferation_key is None and self.apoptosis_key is None:
            logger.warning(
                "At least one of `gene_set_proliferation` or `gene_set_apoptosis` must be provided to score genes."
            )
        return self

    @property
    def proliferation_key(self) -> Optional[str]:
        """Column of :attr:`~anndata.AnnData.obs` holding cell proliferation scores, or :obj:`None`."""
        return self._proliferation_key

    @proliferation_key.setter
    def proliferation_key(self, key: Optional[str]) -> None:
        # ``None`` clears the key; any other value must already be a column of ``adata.obs``.
        if key is None or key in self.adata.obs:
            self._proliferation_key = key
            return
        raise KeyError(f"Unable to find proliferation data in `adata.obs[{key!r}]`.")

    @property
    def apoptosis_key(self) -> Optional[str]:
        """Column of :attr:`~anndata.AnnData.obs` holding cell apoptosis scores, or :obj:`None`."""
        return self._apoptosis_key

    @apoptosis_key.setter
    def apoptosis_key(self, key: Optional[str]) -> None:
        # ``None`` clears the key; any other value must already be a column of ``adata.obs``.
        if key is None or key in self.adata.obs:
            self._apoptosis_key = key
            return
        raise KeyError(f"Unable to find apoptosis data in `adata.obs[{key!r}]`.")
class BirthDeathProblem(BirthDeathMixin, OTProblem):
    """:term:`OT` problem used to estimate the :term:`marginals` with the
    `birth-death process <https://en.wikipedia.org/wiki/Birth%E2%80%93death_process>`_.

    Parameters
    ----------
    args
        Positional arguments for :class:`~moscot.base.problems.OTProblem`.
    kwargs
        Keyword arguments for :class:`~moscot.base.problems.OTProblem`.
    """  # noqa: D205

    def estimate_marginals(
        self,
        adata: AnnData,
        source: bool,
        proliferation_key: Optional[str] = None,
        apoptosis_key: Optional[str] = None,
        scaling: Optional[float] = None,
        beta_max: float = 1.7,
        beta_min: float = 0.3,
        beta_center: float = 0.25,
        beta_width: float = 0.5,
        delta_max: float = 1.7,
        delta_min: float = 0.3,
        delta_center: float = 0.1,
        delta_width: float = 0.2,
    ) -> ArrayLike:
        """Estimate the source or target :term:`marginals` based on marker genes, either with the
        `birth-death process <https://en.wikipedia.org/wiki/Birth%E2%80%93death_process>`_,
        as suggested in :cite:`schiebinger:19`, or with an exponential kernel.

        See :meth:`score_genes_for_marginals` on how to compute the proliferation and apoptosis scores.

        Parameters
        ----------
        adata
            Annotated data object.
        source
            Whether to estimate the source or the target :term:`marginals`.
        proliferation_key
            Key in :attr:`~anndata.AnnData.obs` where proliferation scores are stored.
        apoptosis_key
            Key in :attr:`~anndata.AnnData.obs` where apoptosis scores are stored.
        scaling
            A parameter for prior growth rate estimation.
            If a positive :obj:`float` is passed, it will be used as a scaling parameter in an exponential kernel
            with proliferation and apoptosis scores.
            If :obj:`None`, parameters corresponding to the birth and death processes will be used.
        beta_max
            Argument for :func:`~moscot.base.problems.birth_death.beta`.
        beta_min
            Argument for :func:`~moscot.base.problems.birth_death.beta`.
        beta_center
            Argument for :func:`~moscot.base.problems.birth_death.beta`.
        beta_width
            Argument for :func:`~moscot.base.problems.birth_death.beta`.
        delta_max
            Argument for :func:`~moscot.base.problems.birth_death.delta`.
        delta_min
            Argument for :func:`~moscot.base.problems.birth_death.delta`.
        delta_center
            Argument for :func:`~moscot.base.problems.birth_death.delta`.
        delta_width
            Argument for :func:`~moscot.base.problems.birth_death.delta`.

        Returns
        -------
        The estimated source or target marginals of shape ``[n,]`` or ``[m,]``, depending on the ``source``.
        If ``source = True``, also updates the following fields:

        - :attr:`prior_growth_rates` - prior estimate of the source growth rates.

        Raises
        ------
        ValueError
            If neither ``proliferation_key`` nor ``apoptosis_key`` is specified,
            or if ``scaling`` is given but not positive.
        KeyError
            If a given key is missing from :attr:`~anndata.AnnData.obs` of ``adata``
            or of the source data.

        Examples
        --------
        - See :doc:`../../notebooks/examples/problems/800_score_genes_for_marginals`
          on examples how to use :meth:`~moscot.problems.time.TemporalProblem.score_genes_for_marginals`.
        """  # noqa: D205

        def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike], **kwargs: Any) -> ArrayLike:
            # A missing key contributes zeros (unmapped by ``fn``) to ``birth - death``.
            if key is None:
                return np.zeros(adata.n_obs, dtype=float)
            # Guard only the lookup, so a KeyError raised inside ``fn`` is not misreported
            # as a missing column.
            try:
                scores = adata.obs[key].values
            except KeyError:
                raise KeyError(f"Unable to get data from `adata.obs[{key!r}]`.") from None
            return fn(scores.astype(float), **kwargs)

        if proliferation_key is None and apoptosis_key is None:
            raise ValueError("At least one of `proliferation_key` or `apoptosis_key` must be specified.")
        # ``0`` used to be treated like ``None`` (it is falsy); a non-positive scale is
        # meaningless for the exponential kernel and would divide by zero, so reject it.
        if scaling is not None and scaling <= 0:
            raise ValueError(f"Expected `scaling` to be positive, found `{scaling}`.")

        # TODO(michalk8): why does this need to be set?
        # NOTE: the setters validate against ``self.adata`` (the source data), not against ``adata``.
        self.proliferation_key = proliferation_key
        self.apoptosis_key = apoptosis_key

        if scaling is not None:
            # Exponential kernel: raw scores are used as-is and divided by ``scaling`` below.
            beta_fn = delta_fn = lambda x: x
        else:
            # Birth-death process: map the scores through the logistic rate curves.
            # ``delta`` here is the module-level function, not the ``delta`` property:
            # class attributes are not visible inside method bodies.
            beta_fn = partial(
                beta, beta_max=beta_max, beta_min=beta_min, beta_center=beta_center, beta_width=beta_width
            )
            delta_fn = partial(
                delta, delta_max=delta_max, delta_min=delta_min, delta_center=delta_center, delta_width=delta_width
            )
            scaling = 1.0

        birth = estimate(proliferation_key, fn=beta_fn)
        death = estimate(apoptosis_key, fn=delta_fn)
        # ``self.delta`` is the time difference between source and target.
        prior_growth = np.exp((birth - death) * self.delta / scaling)
        total = np.sum(prior_growth)
        normalized_growth = prior_growth / total

        if source:
            self._scaling = total
            self._prior_growth = prior_growth
            return normalized_growth
        # Target marginals: every target cell gets the mean of the normalized growth computed from ``adata``.
        return np.full(self.adata_tgt.n_obs, fill_value=np.mean(normalized_growth))

    @property
    def adata(self) -> AnnData:
        """Annotated data object (the source data)."""
        return self.adata_src

    @property
    def prior_growth_rates(self) -> Optional[ArrayLike]:
        """Prior estimate of the source growth rates.

        :obj:`None` until :meth:`estimate_marginals` was called with ``source = True``.
        """
        if self._prior_growth is None:
            return None
        # Undo the ``delta`` exponent applied in ``estimate_marginals`` to get per-unit-time rates.
        return np.asarray(np.power(self._prior_growth, 1.0 / self.delta))

    @property
    def posterior_growth_rates(self) -> Optional[ArrayLike]:
        """Posterior estimate of the source growth rates, derived from the solution's source marginal.

        :obj:`None` if the problem has not been solved.
        """
        if self.solution is None:
            return None
        # NOTE(review): the ``delta`` property returns a difference of the keys, so this branch
        # looks unreachable unless a subclass overrides ``delta`` — confirm before removing.
        if self.delta is None:
            return self.solution.a * self.adata.n_obs
        return np.asarray(self.solution.a * self._scaling) ** (1.0 / self.delta)

    @property
    def delta(self) -> float:
        """Time difference between the source and the target."""
        if TYPE_CHECKING:
            assert isinstance(self._src_key, float)
            assert isinstance(self._tgt_key, float)
        return self._tgt_key - self._src_key
def _logistic(x: ArrayLike, L: float, k: float, center: float = 0) -> ArrayLike:
"""Logistic function."""
return L / (1 + np.exp(-k * (x - center)))
def _gen_logistic(p: ArrayLike, sup: float, inf: float, center: float, width: float) -> ArrayLike:
"""Shifted logistic function."""
return inf + _logistic(p, L=sup - inf, k=4.0 / width, center=center)
def beta(
    p: ArrayLike,
    beta_max: float = 1.7,
    beta_min: float = 0.3,
    beta_center: float = 0.25,
    beta_width: float = 0.5,
) -> ArrayLike:
    """Birth rate as a function of the proliferation score.

    Maps each score through a logistic curve that, for a positive ``beta_width``, rises
    from ``beta_min`` to ``beta_max`` with its midpoint at ``beta_center``.

    Parameters
    ----------
    p
        Proliferation scores.
    beta_max
        Upper asymptote of the birth rate.
    beta_min
        Lower asymptote of the birth rate.
    beta_center
        Score at which the birth rate is halfway between ``beta_min`` and ``beta_max``.
    beta_width
        Width of the transition region around ``beta_center``.

    Returns
    -------
    The birth rate for each score.
    """
    return _gen_logistic(p, sup=beta_max, inf=beta_min, center=beta_center, width=beta_width)
def delta(
    a: ArrayLike,
    delta_max: float = 1.7,
    delta_min: float = 0.3,
    delta_center: float = 0.1,
    delta_width: float = 0.2,
) -> ArrayLike:
    """Death rate as a function of the apoptosis score.

    Maps each score through a logistic curve that, for a positive ``delta_width``, rises
    from ``delta_min`` to ``delta_max`` with its midpoint at ``delta_center``.

    Parameters
    ----------
    a
        Apoptosis scores.
    delta_max
        Upper asymptote of the death rate.
    delta_min
        Lower asymptote of the death rate.
    delta_center
        Score at which the death rate is halfway between ``delta_min`` and ``delta_max``.
    delta_width
        Width of the transition region around ``delta_center``.

    Returns
    -------
    The death rate for each score.
    """
    return _gen_logistic(a, sup=delta_max, inf=delta_min, center=delta_center, width=delta_width)