# Table of Contents

* [moco](#moco)
* [bionemo.moco.distributions](#mocodistributions)
* [bionemo.moco.distributions.prior.distribution](#mocodistributionspriordistribution)
* [bionemo.moco.distributions.prior.discrete.uniform](#mocodistributionspriordiscreteuniform)
* [bionemo.moco.distributions.prior.discrete.custom](#mocodistributionspriordiscretecustom)
* [bionemo.moco.distributions.prior.discrete](#mocodistributionspriordiscrete)
* [bionemo.moco.distributions.prior.discrete.mask](#mocodistributionspriordiscretemask)
* [bionemo.moco.distributions.prior.continuous.harmonic](#mocodistributionspriorcontinuousharmonic)
* [bionemo.moco.distributions.prior.continuous](#mocodistributionspriorcontinuous)
* [bionemo.moco.distributions.prior.continuous.gaussian](#mocodistributionspriorcontinuousgaussian)
* [bionemo.moco.distributions.prior.continuous.utils](#mocodistributionspriorcontinuousutils)
* [bionemo.moco.distributions.prior](#mocodistributionsprior)
* [bionemo.moco.distributions.time.distribution](#mocodistributionstimedistribution)
* [bionemo.moco.distributions.time.uniform](#mocodistributionstimeuniform)
* [bionemo.moco.distributions.time.logit\_normal](#mocodistributionstimelogit_normal)
* [bionemo.moco.distributions.time](#mocodistributionstime)
* [bionemo.moco.distributions.time.beta](#mocodistributionstimebeta)
* [bionemo.moco.distributions.time.utils](#mocodistributionstimeutils)
* [bionemo.moco.schedules.noise.continuous\_snr\_transforms](#mocoschedulesnoisecontinuous_snr_transforms)
* [bionemo.moco.schedules.noise.discrete\_noise\_schedules](#mocoschedulesnoisediscrete_noise_schedules)
* [bionemo.moco.schedules.noise](#mocoschedulesnoise)
* [bionemo.moco.schedules.noise.continuous\_noise\_transforms](#mocoschedulesnoisecontinuous_noise_transforms)
* [bionemo.moco.schedules](#mocoschedules)
* [bionemo.moco.schedules.utils](#mocoschedulesutils)
* [bionemo.moco.schedules.inference\_time\_schedules](#mocoschedulesinference_time_schedules)
* [bionemo.moco.interpolants.continuous\_time.discrete](#mocointerpolantscontinuous_timediscrete)
* [bionemo.moco.interpolants.continuous\_time.discrete.mdlm](#mocointerpolantscontinuous_timediscretemdlm)
* [bionemo.moco.interpolants.continuous\_time.discrete.discrete\_flow\_matching](#mocointerpolantscontinuous_timediscretediscrete_flow_matching)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.ot\_sampler](#mocointerpolantscontinuous_timecontinuousdata_augmentationot_sampler)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.equivariant\_ot\_sampler](#mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_sampler)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.kabsch\_augmentation](#mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentation)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation](#mocointerpolantscontinuous_timecontinuousdata_augmentation)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.augmentation\_types](#mocointerpolantscontinuous_timecontinuousdata_augmentationaugmentation_types)
* [bionemo.moco.interpolants.continuous\_time.continuous](#mocointerpolantscontinuous_timecontinuous)
* [bionemo.moco.interpolants.continuous\_time.continuous.vdm](#mocointerpolantscontinuous_timecontinuousvdm)
* [bionemo.moco.interpolants.continuous\_time.continuous.continuous\_flow\_matching](#mocointerpolantscontinuous_timecontinuouscontinuous_flow_matching)
* [bionemo.moco.interpolants.continuous\_time](#mocointerpolantscontinuous_time)
* [bionemo.moco.interpolants](#mocointerpolants)
* [bionemo.moco.interpolants.batch\_augmentation](#mocointerpolantsbatch_augmentation)
* [bionemo.moco.interpolants.discrete\_time.discrete.d3pm](#mocointerpolantsdiscrete_timediscreted3pm)
* [bionemo.moco.interpolants.discrete\_time.discrete](#mocointerpolantsdiscrete_timediscrete)
* [bionemo.moco.interpolants.discrete\_time.continuous.ddpm](#mocointerpolantsdiscrete_timecontinuousddpm)
* [bionemo.moco.interpolants.discrete\_time.continuous](#mocointerpolantsdiscrete_timecontinuous)
* [bionemo.moco.interpolants.discrete\_time](#mocointerpolantsdiscrete_time)
* [bionemo.moco.interpolants.discrete\_time.utils](#mocointerpolantsdiscrete_timeutils)
* [bionemo.moco.interpolants.base\_interpolant](#mocointerpolantsbase_interpolant)
* [bionemo.moco.testing](#mocotesting)
* [bionemo.moco.testing.parallel\_test\_utils](#mocotestingparallel_test_utils)

<a id="moco"></a>

# moco

<a id="mocodistributions"></a>

# bionemo.moco.distributions

<a id="mocodistributionspriordistribution"></a>

# bionemo.moco.distributions.prior.distribution

<a id="mocodistributionspriordistributionPriorDistribution"></a>

## PriorDistribution Objects

```python
class PriorDistribution(ABC)
```

An abstract base class representing a prior distribution.

<a id="mocodistributionspriordistributionPriorDistributionsample"></a>

#### sample

```python
@abstractmethod
def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu") -> Tensor
```

Generates samples of the specified shape from the prior distribution.

**Arguments**:

- `shape` _Tuple_ - The shape of the samples to generate.
- `mask` _Optional[Tensor], optional_ - A tensor indicating which samples should be masked. Defaults to None.
- `device` _str, optional_ - The device on which to generate the samples. Defaults to "cpu".

**Returns**:

- `Float` - A tensor of samples.

<a id="mocodistributionspriordistributionDiscretePriorDistribution"></a>

## DiscretePriorDistribution Objects

```python
class DiscretePriorDistribution(PriorDistribution)
```

An abstract base class representing a discrete prior distribution.

<a id="mocodistributionspriordistributionDiscretePriorDistribution__init__"></a>

#### \_\_init\_\_

```python
def __init__(num_classes: int, prior_dist: Tensor)
```

Initializes a DiscretePriorDistribution instance.

**Arguments**:

- `num_classes` _int_ - The number of classes in the discrete distribution.
- `prior_dist` _Tensor_ - The prior distribution over the classes.

**Returns**:

None

<a id="mocodistributionspriordistributionDiscretePriorDistributionget_num_classes"></a>

#### get\_num\_classes

```python
def get_num_classes() -> int
```

Getter for `num_classes`.

<a id="mocodistributionspriordistributionDiscretePriorDistributionget_prior_dist"></a>

#### get\_prior\_dist

```python
def get_prior_dist() -> Tensor
```

Getter for `prior_dist`.

<a id="mocodistributionspriordiscreteuniform"></a>

# bionemo.moco.distributions.prior.discrete.uniform

<a id="mocodistributionspriordiscreteuniformDiscreteUniformPrior"></a>

## DiscreteUniformPrior Objects

```python
class DiscreteUniformPrior(DiscretePriorDistribution)
```

A subclass representing a discrete uniform prior distribution.

<a id="mocodistributionspriordiscreteuniformDiscreteUniformPrior__init__"></a>

#### \_\_init\_\_

```python
def __init__(num_classes: int = 10) -> None
```

Initializes a discrete uniform prior distribution.

**Arguments**:

- `num_classes` _int_ - The number of classes in the discrete uniform distribution. Defaults to 10.

<a id="mocodistributionspriordiscreteuniformDiscreteUniformPriorsample"></a>

#### sample

```python
def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor
```

Generates samples of the specified shape.

**Arguments**:

- `shape` _Tuple_ - The shape of the samples to generate.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples (e.g., "cpu" or "cuda"). Defaults to "cpu".
- `mask` _Optional[Tensor]_ - An optional mask to apply to the samples. Defaults to None.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

- `Float` - A tensor of samples.
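
To make the sampling semantics concrete without depending on PyTorch, here is a minimal standard-library sketch of what a discrete uniform prior draws: integer class indices, each equally likely, arranged in a 2D shape. `random.Random` stands in for `torch.Generator`, the helper `uniform_class_sample` is hypothetical, and the `mask` argument is not modeled; it is an illustration, not the library's implementation.

```python
import random
from typing import List, Optional, Tuple


def uniform_class_sample(shape: Tuple[int, int], num_classes: int = 10,
                         rng: Optional[random.Random] = None) -> List[List[int]]:
    """Hypothetical sketch: draw class indices uniformly from range(num_classes)."""
    rng = rng if rng is not None else random.Random()
    rows, cols = shape
    # Every entry is an independent, equally likely choice among the classes.
    return [[rng.randrange(num_classes) for _ in range(cols)] for _ in range(rows)]


# Seeding plays the role of rng_generator: the same seed gives the same draw.
samples = uniform_class_sample((2, 5), num_classes=4, rng=random.Random(0))
print(samples)
```

Reusing a seeded generator is what makes repeated calls reproducible, which is the purpose of the `rng_generator` argument documented above.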

<a id="mocodistributionspriordiscretecustom"></a>

# bionemo.moco.distributions.prior.discrete.custom

<a id="mocodistributionspriordiscretecustomDiscreteCustomPrior"></a>

## DiscreteCustomPrior Objects

```python
class DiscreteCustomPrior(DiscretePriorDistribution)
```

A subclass representing a discrete custom prior distribution.

This class allows for the creation of a prior distribution with a custom probability mass function defined by the `prior_dist` tensor. For example, if the data has 4 classes, `prior_dist` could be `[0.3, 0.2, 0.4, 0.1]`, giving the probability of each class.

<a id="mocodistributionspriordiscretecustomDiscreteCustomPrior__init__"></a>

#### \_\_init\_\_

```python
def __init__(prior_dist: Tensor, num_classes: int = 10) -> None
```

Initializes a DiscreteCustomPrior distribution.

**Arguments**:

- `prior_dist` - A tensor representing the probability mass function of the prior distribution.
- `num_classes` - The number of classes in the prior distribution. Defaults to 10.

**Notes**:

The `prior_dist` tensor should sum to approximately 1.0, as it represents a probability mass function.

<a id="mocodistributionspriordiscretecustomDiscreteCustomPriorsample"></a>

#### sample

```python
def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor
```

Samples from the discrete custom prior distribution.

**Arguments**:

- `shape` - A tuple specifying the shape of the samples to generate.
- `mask` - An optional tensor mask to apply to the samples, broadcastable to the sample shape. Defaults to None.
- `device` - The device on which to generate the samples, specified as a string or a :class:`torch.device`. Defaults to "cpu".
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

A tensor of samples drawn from the prior distribution.
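
Sampling from a custom prior amounts to drawing class indices with probabilities given by `prior_dist`. The standard-library sketch below checks the sum-to-one condition from the note above and then samples with `random.Random.choices`; the helper `custom_prior_sample` is hypothetical and not the library's implementation (which operates on tensors).

```python
import math
import random
from typing import List, Optional, Sequence


def custom_prior_sample(prior_dist: Sequence[float], n: int,
                        rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical sketch: draw n class indices from a custom PMF."""
    # The PMF should sum to (approximately) 1.0.
    if not math.isclose(sum(prior_dist), 1.0, rel_tol=1e-6):
        raise ValueError("prior_dist must sum to 1.0")
    rng = rng if rng is not None else random.Random()
    # choices() picks each index with probability proportional to its weight.
    return rng.choices(range(len(prior_dist)), weights=prior_dist, k=n)


draws = custom_prior_sample([0.3, 0.2, 0.4, 0.1], n=10_000, rng=random.Random(0))
# Empirical class frequencies should approach the target PMF.
freqs = [draws.count(c) / len(draws) for c in range(4)]
print([round(f, 2) for f in freqs])
```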

<a id="mocodistributionspriordiscrete"></a>

# bionemo.moco.distributions.prior.discrete

<a id="mocodistributionspriordiscretemask"></a>

# bionemo.moco.distributions.prior.discrete.mask

<a id="mocodistributionspriordiscretemaskDiscreteMaskedPrior"></a>

## DiscreteMaskedPrior Objects

```python
class DiscreteMaskedPrior(DiscretePriorDistribution)
```

A subclass representing a discrete masked prior distribution.

<a id="mocodistributionspriordiscretemaskDiscreteMaskedPrior__init__"></a>

#### \_\_init\_\_

```python
def __init__(num_classes: int = 10,
             mask_dim: Optional[int] = None,
             inclusive: bool = True) -> None
```

Discrete Masked prior distribution.

There are three ways to define the problem, and they are hard to reconcile with one another:

1. `[..., M, ...]` inclusive, anywhere: an existing LLM tokenizer where the mask token has a specific index that is not at the end.
2. `[..., M]` inclusive, on the end: `mask_dim = None` with `inclusive = True` (the default) places the mask on the end.
3. `[...] + [M]` exclusive: the number of classes represents the number of data classes, and one wishes to add a separate MASK dimension.
   - Note that the `pad_sample` function is provided to help add this extra external dimension.

**Arguments**:

- `num_classes` _int_ - The number of classes in the distribution. Defaults to 10.
- `mask_dim` _int_ - The index for the mask token. Defaults to num_classes - 1 if inclusive or num_classes if exclusive.
- `inclusive` _bool_ - Whether the mask is included in the specified number of classes.
  If True, the mask is considered as one of the classes.
  If False, the mask is considered as an additional class. Defaults to True.
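
The interaction between `num_classes`, `mask_dim`, and `inclusive` can be summarized as a small lookup. The sketch below (a hypothetical helper, `resolve_mask_layout`, not part of the library) returns the total number of states and the mask index, using the documented defaults for `mask_dim`; the `num_classes + 1` total for the exclusive case follows from the mask being an additional class.

```python
from typing import Optional, Tuple


def resolve_mask_layout(num_classes: int = 10, mask_dim: Optional[int] = None,
                        inclusive: bool = True) -> Tuple[int, int]:
    """Hypothetical helper: return (total_states, mask_index) for a layout."""
    if inclusive:
        # Cases 1 and 2: the mask is one of the num_classes states.
        total = num_classes
        index = num_classes - 1 if mask_dim is None else mask_dim
    else:
        # Case 3: the mask is an extra state appended after the data classes.
        total = num_classes + 1
        index = num_classes if mask_dim is None else mask_dim
    if not 0 <= index < total:
        raise ValueError(f"mask index {index} is out of range for {total} states")
    return total, index


print(resolve_mask_layout(10))                   # case 2: mask on the end
print(resolve_mask_layout(10, mask_dim=3))       # case 1: mask at a fixed index
print(resolve_mask_layout(10, inclusive=False))  # case 3: separate MASK dimension
```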

<a id="mocodistributionspriordiscretemaskDiscreteMaskedPriorsample"></a>

#### sample

```python
def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor
```

Generates samples of the specified shape.

**Arguments**:

- `shape` _Tuple_ - The shape of the samples to generate.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples (e.g., "cpu" or "cuda"). Defaults to "cpu".
- `mask` _Optional[Tensor]_ - An optional mask to apply to the samples. Defaults to None.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

- `Float` - A tensor of samples.

<a id="mocodistributionspriordiscretemaskDiscreteMaskedPrioris_masked"></a>

#### is\_masked

```python
def is_masked(sample: Tensor) -> Tensor
```

Creates a mask indicating which entries of a sample are in the masked state.

**Arguments**:

- `sample` _Tensor_ - The sample to check.

**Returns**:

- `Tensor` - A float tensor indicating whether the sample is masked.

<a id="mocodistributionspriordiscretemaskDiscreteMaskedPriorpad_sample"></a>

#### pad\_sample

```python
def pad_sample(sample: Tensor) -> Tensor
```

Pads the input sample with zeros along the last dimension.

**Arguments**:

- `sample` _Tensor_ - The input sample to be padded.

**Returns**:

- `Tensor` - The padded sample.
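
For the exclusive layout (case 3), `is_masked` and `pad_sample` work together: one flags mask states, the other adds room for the extra MASK dimension. The list-based sketch below illustrates both under two assumptions not stated in the source: `is_masked` compares each state to the mask index, and `pad_sample` appends exactly one zero per row. Both helpers are hypothetical stand-ins, not the library's tensor code.

```python
from typing import List


def is_masked(sample: List[int], mask_index: int) -> List[float]:
    """Hypothetical sketch: 1.0 where a state equals the mask index, else 0.0."""
    return [float(s == mask_index) for s in sample]


def pad_sample(sample: List[List[float]]) -> List[List[float]]:
    """Hypothetical sketch: append one zero to the last dimension of each row."""
    return [row + [0.0] for row in sample]


# Exclusive layout with 3 data classes: index 3 is the extra MASK state.
print(is_masked([0, 3, 2, 3], mask_index=3))
# One-hot rows over the 3 data classes gain a zero-valued MASK column.
print(pad_sample([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]))
```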

<a id="mocodistributionspriorcontinuousharmonic"></a>

# bionemo.moco.distributions.prior.continuous.harmonic

<a id="mocodistributionspriorcontinuousharmonicLinearHarmonicPrior"></a>

## LinearHarmonicPrior Objects

```python
class LinearHarmonicPrior(PriorDistribution)
```

A subclass representing a Linear Harmonic prior distribution from Jing et al. https://arxiv.org/abs/2304.02198.

<a id="mocodistributionspriorcontinuousharmonicLinearHarmonicPrior__init__"></a>

#### \_\_init\_\_

```python
def __init__(length: Optional[int] = None,
             distance: Float = 3.8,
             center: Bool = False,
             rng_generator: Optional[torch.Generator] = None,
             device: Union[str, torch.device] = "cpu") -> None
```

Linear Harmonic prior distribution.

**Arguments**:

- `length` _Optional[int]_ - The number of points in a batch.
- `distance` _Float_ - RMS distance between adjacent points in the line graph.
- `center` _bool_ - Whether to center the samples around the mean. Defaults to False.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples. Defaults to "cpu".
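
A linear harmonic prior can be pictured as a Gaussian chain. Assuming a density proportional to `exp(-(a/2) * sum_i ||x[i+1] - x[i]||^2)` over 3D points, the bond vectors between adjacent points are independent Gaussians, and a per-coordinate variance of `distance**2 / 3` makes the RMS adjacent distance equal `distance`. The standard-library sketch below draws a centered chain with that random-walk construction. It is a swapped-in illustration, not necessarily how `LinearHarmonicPrior` computes its samples. The helper `harmonic_chain_sample` is hypothetical, and it always centers the chain (our reading of `center=True`).

```python
import math
import random
from typing import List, Optional

Point = List[float]


def harmonic_chain_sample(length: int, distance: float = 3.8,
                          rng: Optional[random.Random] = None) -> List[Point]:
    """Hypothetical sketch: draw a centered 3D chain from a linear harmonic prior."""
    if length < 1:
        raise ValueError("length must be at least 1")
    rng = rng if rng is not None else random.Random()
    # Per-coordinate std so that E[||x[i+1] - x[i]||^2] = distance^2.
    sigma = distance / math.sqrt(3.0)
    coords: List[Point] = [[0.0, 0.0, 0.0]]
    for _ in range(length - 1):
        prev = coords[-1]
        # Random-walk step: one independent Gaussian bond vector per adjacent pair.
        coords.append([c + rng.gauss(0.0, sigma) for c in prev])
    # Subtract the mean position so the chain has zero center of mass.
    mean = [sum(p[d] for p in coords) / length for d in range(3)]
    return [[p[d] - mean[d] for d in range(3)] for p in coords]


chain = harmonic_chain_sample(2000, rng=random.Random(0))
bonds = [math.dist(a, b) for a, b in zip(chain, chain[1:])]
rms = math.sqrt(sum(b * b for b in bonds) / len(bonds))
print(round(rms, 2))  # close to the requested distance of 3.8
```

Centering only removes the translational degree of freedom, so the bond lengths, and hence the RMS check, are unaffected by it.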

<a id="mocodistributionspriorcontinuousharmonicLinearHarmonicPriorsample"></a>

#### sample

```python
def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor
```

Generates samples of the specified shape from the Harmonic prior distribution.

**Arguments**:

- `shape` _Tuple_ - The shape of the samples to generate.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples (e.g., "cpu" or "cuda"). Defaults to "cpu".
- `mask` _Optional[Tensor]_ - An optional mask to apply to the samples. Defaults to None.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

- `Float` - A tensor of samples.

<a id="mocodistributionspriorcontinuous"></a>

# bionemo.moco.distributions.prior.continuous

<a id="mocodistributionspriorcontinuousgaussian"></a>

# bionemo.moco.distributions.prior.continuous.gaussian

<a id="mocodistributionspriorcontinuousgaussianGaussianPrior"></a>

## GaussianPrior Objects

```python
class GaussianPrior(PriorDistribution)
```

A subclass representing a Gaussian prior distribution.

<a id="mocodistributionspriorcontinuousgaussianGaussianPrior__init__"></a>

#### \_\_init\_\_

```python
def __init__(mean: Float = 0.0,
             std: Float = 1.0,
             center: Bool = False,
             rng_generator: Optional[torch.Generator] = None) -> None
```

Gaussian prior distribution.

**Arguments**:

- `mean` _Float_ - The mean of the Gaussian distribution. Defaults to 0.0.
- `std` _Float_ - The standard deviation of the Gaussian distribution. Defaults to 1.0.
- `center` _bool_ - Whether to center the samples around the mean. Defaults to False.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocodistributionspriorcontinuousgaussianGaussianPriorsample"></a>

#### sample

```python
def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor
```

Generates samples of the specified shape from the Gaussian prior distribution.

**Arguments**:

- `shape` _Tuple_ - The shape of the samples to generate.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples (e.g., "cpu" or "cuda"). Defaults to "cpu".
- `mask` _Optional[Tensor]_ - An optional mask to apply to the samples. Defaults to None.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

- `Float` - A tensor of samples.
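
As a rough picture of what the Gaussian prior produces for point-cloud style data, the sketch below draws i.i.d. Gaussian values for a `(nodes, features)` array. With `center=True`, it subtracts the per-feature mean over nodes. That reading of `center` is an assumption, as is the hypothetical helper name `gaussian_prior_sample`. The sketch uses only the standard library and is not the library's implementation.

```python
import random
from typing import List, Optional


def gaussian_prior_sample(nodes: int, features: int, mean: float = 0.0,
                          std: float = 1.0, center: bool = False,
                          rng: Optional[random.Random] = None) -> List[List[float]]:
    """Hypothetical sketch: i.i.d. Gaussian draws for a (nodes, features) array."""
    rng = rng if rng is not None else random.Random()
    x = [[rng.gauss(mean, std) for _ in range(features)] for _ in range(nodes)]
    if center:
        # Assumed meaning of center=True: subtract the per-feature mean over nodes.
        com = [sum(row[f] for row in x) / nodes for f in range(features)]
        x = [[row[f] - com[f] for f in range(features)] for row in x]
    return x


sample = gaussian_prior_sample(5, 3, center=True, rng=random.Random(0))
# Each column sum is zero up to floating-point error after centering.
print([sum(row[f] for row in sample) for f in range(3)])
```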

<a id="mocodistributionspriorcontinuousutils"></a>

# bionemo.moco.distributions.prior.continuous.utils

<a id="mocodistributionspriorcontinuousutilsremove_center_of_mass"></a>

#### remove\_center\_of\_mass

```python
def remove_center_of_mass(data: Tensor,
                          mask: Optional[Tensor] = None) -> Tensor
```

Removes the center of mass (CoM) from the given data.

**Arguments**:

- `data` - The input data with shape (..., nodes, features).
- `mask` - An optional binary mask with shape (..., nodes) that excludes nodes from the CoM calculation. Defaults to None.

**Returns**:

The data with its CoM subtracted, with shape (..., nodes, features).
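
The masked CoM computation is easiest to see on a single `(nodes, features)` array. The sketch below averages only over nodes whose mask entry is 1 and then subtracts that CoM. It assumes, as the function name suggests, that the result is the shifted data, and that masked-out nodes are shifted too. The name `remove_com_sketch` is hypothetical; this is a list-based illustration, not the library function.

```python
from typing import List, Optional


def remove_com_sketch(data: List[List[float]],
                      mask: Optional[List[int]] = None) -> List[List[float]]:
    """Sketch for one (nodes, features) array; mask[i] == 1 keeps node i in the CoM."""
    features = len(data[0])
    weights = mask if mask is not None else [1] * len(data)
    count = sum(weights)
    # CoM over the selected nodes only, one value per feature.
    com = [sum(w * row[f] for w, row in zip(weights, data)) / count
           for f in range(features)]
    # Assumption: the CoM is subtracted from every node, masked or not.
    return [[row[f] - com[f] for f in range(features)] for row in data]


data = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
print(remove_com_sketch(data, mask=[1, 1, 0]))
# [[-1.0, -1.0], [1.0, 1.0], [98.0, 97.0]]
```

The masked-out outlier at `[100.0, 100.0]` does not pull the CoM, which is the point of passing `mask`.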

<a id="mocodistributionsprior"></a>

# bionemo.moco.distributions.prior

<a id="mocodistributionstimedistribution"></a>

# bionemo.moco.distributions.time.distribution

<a id="mocodistributionstimedistributionTimeDistribution"></a>

## TimeDistribution Objects

```python
class TimeDistribution(ABC)
```

An abstract base class representing a time distribution.

**Arguments**:

- `discrete_time` _Bool_ - Whether the time is discrete.
- `nsteps` _Optional[int]_ - Number of time steps used for discretization.
- `min_t` _Optional[Float]_ - Minimum continuous time.
- `max_t` _Optional[Float]_ - Maximum continuous time.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocodistributionstimedistributionTimeDistribution__init__"></a>

#### \_\_init\_\_

```python
def __init__(discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             min_t: Optional[Float] = None,
             max_t: Optional[Float] = None,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes a TimeDistribution object.

<a id="mocodistributionstimedistributionTimeDistributionsample"></a>

#### sample

```python
@abstractmethod
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Float
```

Generates a specified number of samples from the time distribution.

**Arguments**:

- `n_samples` _Union[int, Tuple[int, ...], torch.Size]_ - The number of samples, or the sample shape, to generate.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples (e.g., "cpu" or "cuda"). Defaults to "cpu".
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

- `Float` - A tensor of samples.

<a id="mocodistributionstimedistributionMixTimeDistribution"></a>

## MixTimeDistribution Objects

```python
class MixTimeDistribution()
```

A class representing a mixture of two time distributions.

Example:

```python
uniform_dist = UniformTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False)
beta_dist = BetaTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False, p1=2.0, p2=1.0)
mix_dist = MixTimeDistribution(uniform_dist, beta_dist, mix_fraction=0.5)
```

<a id="mocodistributionstimedistributionMixTimeDistribution__init__"></a>

#### \_\_init\_\_

```python
def __init__(dist1: TimeDistribution, dist2: TimeDistribution,
             mix_fraction: Float)
```

Initializes a MixTimeDistribution object.

**Arguments**:

- `dist1` _TimeDistribution_ - The first time distribution.
- `dist2` _TimeDistribution_ - The second time distribution.
- `mix_fraction` _Float_ - The fraction of samples to draw from dist1. Must be between 0 and 1.

<a id="mocodistributionstimedistributionMixTimeDistributionsample"></a>

#### sample

```python
def sample(n_samples: int,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Float
```

Generates a specified number of samples from the mixed time distribution.

**Arguments**:

- `n_samples` _int_ - The number of samples to generate.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples (e.g., "cpu" or "cuda"). Defaults to "cpu".
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

- `Float` - A tensor of samples.
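
One simple way to realize `mix_fraction` is a per-sample coin flip: each time is drawn from `dist1` with probability `mix_fraction` and from `dist2` otherwise. The standard-library sketch below does this, echoing the uniform/beta example above. Per-sample selection is an assumption (an implementation could instead take an exact fraction), as is reading `p1=2.0, p2=1.0` as Beta(2, 1). The helper `mix_time_sample` is hypothetical.

```python
import random
from typing import Callable, List, Optional


def mix_time_sample(sample1: Callable[[random.Random], float],
                    sample2: Callable[[random.Random], float],
                    mix_fraction: float, n_samples: int,
                    rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical sketch: each time comes from dist1 with probability mix_fraction."""
    if not 0.0 <= mix_fraction <= 1.0:
        raise ValueError("mix_fraction must be between 0 and 1")
    rng = rng if rng is not None else random.Random()
    return [sample1(rng) if rng.random() < mix_fraction else sample2(rng)
            for _ in range(n_samples)]


# Uniform(0, 1) mixed 50/50 with Beta(2, 1).
times = mix_time_sample(lambda r: r.random(), lambda r: r.betavariate(2.0, 1.0),
                        mix_fraction=0.5, n_samples=10_000, rng=random.Random(0))
print(round(sum(times) / len(times), 2))
```

The sample mean should be near the mixture mean, `0.5 * 0.5 + 0.5 * 2/3 = 7/12`, which is a quick sanity check on the mixing weight.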

<a id="mocodistributionstimeuniform"></a>

# bionemo.moco.distributions.time.uniform

<a id="mocodistributionstimeuniformUniformTimeDistribution"></a>

## UniformTimeDistribution Objects

```python
class UniformTimeDistribution(TimeDistribution)
```

A class representing a uniform time distribution.

<a id="mocodistributionstimeuniformUniformTimeDistribution__init__"></a>

#### \_\_init\_\_

```python
def __init__(min_t: Float = 0.0,
             max_t: Float = 1.0,
             discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes a UniformTimeDistribution object.

**Arguments**:

- `min_t` _Float_ - The minimum time value.
- `max_t` _Float_ - The maximum time value.
- `discrete_time` _Bool_ - Whether the time is discrete.
- `nsteps` _Optional[int]_ - Number of time steps used for discretization.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocodistributionstimeuniformUniformTimeDistributionsample"></a>

#### sample

```python
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None)
```

Generates a specified number of samples from the uniform time distribution.

**Arguments**:

- `n_samples` _Union[int, Tuple[int, ...], torch.Size]_ - The number of samples, or the sample shape, to generate.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples (e.g., "cpu" or "cuda"). Defaults to "cpu".
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

A tensor of samples.
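
The `discrete_time` flag switches between two kinds of output. The sketch below contrasts them: continuous times spread uniformly over `[min_t, max_t)`, and discrete integer step indices. The `0, ..., nsteps - 1` index convention and the helper name `uniform_time_sample` are assumptions for illustration; it uses the standard library in place of PyTorch.

```python
import random
from typing import List, Optional, Union


def uniform_time_sample(n_samples: int, min_t: float = 0.0, max_t: float = 1.0,
                        discrete_time: bool = False, nsteps: Optional[int] = None,
                        rng: Optional[random.Random] = None) -> List[Union[int, float]]:
    """Hypothetical sketch of continuous vs. discrete uniform time sampling."""
    rng = rng if rng is not None else random.Random()
    if discrete_time:
        if nsteps is None:
            raise ValueError("nsteps is required when discrete_time=True")
        # Assumed convention: integer step indices 0, 1, ..., nsteps - 1.
        return [rng.randrange(nsteps) for _ in range(n_samples)]
    # Continuous times spread uniformly over [min_t, max_t).
    return [min_t + (max_t - min_t) * rng.random() for _ in range(n_samples)]


continuous = uniform_time_sample(4, rng=random.Random(0))
discrete = uniform_time_sample(4, discrete_time=True, nsteps=1000, rng=random.Random(0))
print(continuous, discrete)
```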

<a id="mocodistributionstimeuniformSymmetricUniformTimeDistribution"></a>

## SymmetricUniformTimeDistribution Objects

```python
class SymmetricUniformTimeDistribution(TimeDistribution)
```

A class representing a symmetric uniform time distribution.

<a id="mocodistributionstimeuniformSymmetricUniformTimeDistribution__init__"></a>

#### \_\_init\_\_

```python
def __init__(min_t: Float = 0.0,
             max_t: Float = 1.0,
             discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes a SymmetricUniformTimeDistribution object.

**Arguments**:

- `min_t` _Float_ - The minimum time value.
- `max_t` _Float_ - The maximum time value.
- `discrete_time` _Bool_ - Whether the time is discrete.
- `nsteps` _Optional[int]_ - Number of time steps used for discretization.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocodistributionstimeuniformSymmetricUniformTimeDistributionsample"></a>

#### sample

```python
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None)
```

Generates a specified number of samples from the symmetric uniform time distribution.

**Arguments**:

- `n_samples` _Union[int, Tuple[int, ...], torch.Size]_ - The number of samples, or the sample shape, to generate.
- `device` _Union[str, torch.device]_ - The device on which to generate the samples (e.g., "cpu" or "cuda"). Defaults to "cpu".
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

A tensor of samples.

<a id="mocodistributionstimelogit_normal"></a>

# bionemo.moco.distributions.time.logit\_normal

<a id="mocodistributionstimelogit_normalLogitNormalTimeDistribution"></a>

## LogitNormalTimeDistribution Objects

```python
class LogitNormalTimeDistribution(TimeDistribution)
```

A class representing a logit normal time distribution.

<a id="mocodistributionstimelogit_normalLogitNormalTimeDistribution__init__"></a>

#### \_\_init\_\_

```python
def __init__(p1: Float = 0.0,
             p2: Float = 1.0,
             min_t: Float = 0.0,
             max_t: Float = 1.0,
             discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes a LogitNormalTimeDistribution object.

**Arguments**:

- `p1` _Float_ - The first shape parameter of the logit normal distribution, i.e., the mean of the underlying normal distribution.
- `p2` _Float_ - The second shape parameter of the logit normal distribution, i.e., the standard deviation of the underlying normal distribution.
- `min_t` _Float_ - The minimum time value.
- `max_t` _Float_ - The maximum time value.
- `discrete_time` _Bool_ - Whether the time is discrete.
- `nsteps` _Optional[int]_ - Number of time steps used for discretization.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
|
|
791 |
|
|
|

<a id="mocodistributionstimelogit_normalLogitNormalTimeDistributionsample"></a>

#### sample

```python
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None)
```

Generates a specified number of samples from the logit-normal time distribution.

**Arguments**:

- `n_samples` _Union[int, Tuple[int, ...], torch.Size]_ - The number (or shape) of samples to generate.
- `device` _Union[str, torch.device]_ - The device on which to place the samples, e.g. "cpu" or "cuda". Defaults to "cpu".
- `rng_generator` _Optional[torch.Generator]_ - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

  A tensor of samples.
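
A logit-normal sample is the logistic sigmoid of a normal draw with mean `p1` and standard deviation `p2`. The sketch below illustrates that idea without any dependencies, using Python's `random` module; it is not the bionemo implementation. Three things in it are assumptions: the helper name `sample_logit_normal`, the linear rescaling into `[min_t, max_t]`, and the omission of discretization.

```python
import math
import random
from typing import List, Optional


def sample_logit_normal(n_samples: int, p1: float = 0.0, p2: float = 1.0,
                        min_t: float = 0.0, max_t: float = 1.0,
                        rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical helper: logit-normal draws rescaled to [min_t, max_t]."""
    rng = rng or random.Random()
    samples = []
    for _ in range(n_samples):
        z = rng.gauss(p1, p2)                        # normal draw: mean p1, std p2
        u = 1.0 / (1.0 + math.exp(-z))               # logistic sigmoid maps R -> (0, 1)
        samples.append(min_t + (max_t - min_t) * u)  # assumed linear rescaling
    return samples


times = sample_logit_normal(5, rng=random.Random(0))
```

With `p1=0.0`, the distribution is symmetric around the midpoint of the interval. Shifting `p1` concentrates samples toward one end.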

<a id="mocodistributionstime"></a>

# bionemo.moco.distributions.time

<a id="mocodistributionstimebeta"></a>

# bionemo.moco.distributions.time.beta

<a id="mocodistributionstimebetaBetaTimeDistribution"></a>

## BetaTimeDistribution Objects

```python
class BetaTimeDistribution(TimeDistribution)
```

A class representing a beta time distribution.

<a id="mocodistributionstimebetaBetaTimeDistribution__init__"></a>

#### \_\_init\_\_

```python
def __init__(p1: Float = 2.0,
             p2: Float = 1.0,
             min_t: Float = 0.0,
             max_t: Float = 1.0,
             discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes a BetaTimeDistribution object.

**Arguments**:

- `p1` _Float_ - The first shape parameter of the beta distribution.
- `p2` _Float_ - The second shape parameter of the beta distribution.
- `min_t` _Float_ - The minimum time value.
- `max_t` _Float_ - The maximum time value.
- `discrete_time` _Bool_ - Whether the time is discrete.
- `nsteps` _Optional[int]_ - Number of steps used for discretization.
- `rng_generator` _Optional[torch.Generator]_ - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocodistributionstimebetaBetaTimeDistributionsample"></a>

#### sample

```python
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None)
```

Generates a specified number of samples from the beta time distribution.

**Arguments**:

- `n_samples` _Union[int, Tuple[int, ...], torch.Size]_ - The number (or shape) of samples to generate.
- `device` _Union[str, torch.device]_ - The device on which to place the samples, e.g. "cpu" or "cuda". Defaults to "cpu".
- `rng_generator` _Optional[torch.Generator]_ - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

**Returns**:

  A tensor of samples.
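
Here is a standalone sketch of beta time sampling. It assumes that samples are drawn from Beta(`p1`, `p2`) on [0, 1] and then linearly rescaled into `[min_t, max_t]`. The sketch uses Python's `random.betavariate` in place of torch, and `sample_beta_time` is a hypothetical name. With the defaults `p1=2.0, p2=1.0`, the Beta density is proportional to t, so later times are drawn more often.

```python
import random
from typing import List, Optional


def sample_beta_time(n_samples: int, p1: float = 2.0, p2: float = 1.0,
                     min_t: float = 0.0, max_t: float = 1.0,
                     rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical helper: Beta(p1, p2) draws rescaled to [min_t, max_t]."""
    rng = rng or random.Random()
    return [min_t + (max_t - min_t) * rng.betavariate(p1, p2) for _ in range(n_samples)]


# For Beta(p1, p2) the mean is p1 / (p1 + p2); with the defaults that is 2/3.
times = sample_beta_time(4000, rng=random.Random(0))
mean_t = sum(times) / len(times)
```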

<a id="mocodistributionstimeutils"></a>

# bionemo.moco.distributions.time.utils

<a id="mocodistributionstimeutilsfloat_time_to_index"></a>

#### float\_time\_to\_index

```python
def float_time_to_index(time: torch.Tensor,
                        num_time_steps: int) -> torch.Tensor
```

Convert a float time value to a time index.

**Arguments**:

- `time` _torch.Tensor_ - A tensor of float time values in the range [0, 1].
- `num_time_steps` _int_ - The number of discrete time steps.

**Returns**:

- `torch.Tensor` - A tensor of time indices corresponding to the input float time values.
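
The exact rounding rule is not stated here. The sketch below assumes one common convention: t in [0, 1] is mapped linearly onto the indices 0 through `num_time_steps - 1` and rounded to the nearest integer. `float_time_to_index_sketch` is hypothetical and uses plain Python lists instead of tensors.

```python
from typing import List


def float_time_to_index_sketch(times: List[float], num_time_steps: int) -> List[int]:
    """Hypothetical mapping from float times in [0, 1] to integer step indices."""
    indices = []
    for t in times:
        if not 0.0 <= t <= 1.0:
            raise ValueError(f"time {t} is outside [0, 1]")
        # Scale to [0, num_time_steps - 1] and round to the nearest index.
        indices.append(int(round(t * (num_time_steps - 1))))
    return indices


print(float_time_to_index_sketch([0.0, 0.5, 1.0], num_time_steps=11))  # [0, 5, 10]
```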

<a id="mocoschedulesnoisecontinuous_snr_transforms"></a>

# bionemo.moco.schedules.noise.continuous\_snr\_transforms

<a id="mocoschedulesnoisecontinuous_snr_transformslog"></a>

#### log

```python
def log(t, eps=1e-20)
```

Compute the natural logarithm of a tensor, clamping values to avoid numerical instability.

**Arguments**:

- `t` _Tensor_ - The input tensor.
- `eps` _float, optional_ - The minimum value to clamp the input tensor to (default is 1e-20).

**Returns**:

- `Tensor` - The natural logarithm of the input tensor.
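
Here is a scalar, dependency-free sketch of the clamped logarithm described above. Values below `eps` are raised to `eps` before the log is taken, so an input of exactly 0 yields a large negative but finite number instead of `-inf`. The real function operates on tensors; `clamped_log` is a hypothetical name.

```python
import math


def clamped_log(t: float, eps: float = 1e-20) -> float:
    """Natural log with the input clamped from below at eps."""
    return math.log(max(t, eps))


print(clamped_log(0.0))     # log(1e-20), about -46.05
print(clamped_log(math.e))  # approximately 1.0
```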

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransform"></a>

## ContinuousSNRTransform Objects

```python
class ContinuousSNRTransform(ABC)
```

A base class for continuous SNR schedules.

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransform__init__"></a>

#### \_\_init\_\_

```python
def __init__(direction: TimeDirection)
```

Initialize the ContinuousSNRTransform.

**Arguments**:

- `direction` _TimeDirection_ - Required. The time direction in which the schedule was built.

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_log_snr"></a>

#### calculate\_log\_snr

```python
def calculate_log_snr(t: Tensor,
                      device: Union[str, torch.device] = "cpu",
                      synchronize: Optional[TimeDirection] = None) -> Tensor
```

Public wrapper that computes the log signal-to-noise ratio (SNR) for the given time steps.

**Arguments**:

- `t` _Tensor_ - The input tensor representing the time steps, with values ranging from 0 to 1.
- `device` _Union[str, torch.device]_ - The device to place the schedule on. Defaults to "cpu".
- `synchronize` _Optional[TimeDirection]_ - TimeDirection to synchronize the schedule with. If the schedule was defined with a different direction, this parameter flips it to match the specified one. Defaults to None.

**Returns**:

- `Tensor` - A tensor representing the log signal-to-noise ratio (SNR) for the given time steps.

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformlog_snr_to_alphas_sigmas"></a>

#### log\_snr\_to\_alphas\_sigmas

```python
def log_snr_to_alphas_sigmas(log_snr: Tensor) -> Tuple[Tensor, Tensor]
```

Converts the log signal-to-noise ratio (SNR) to alpha and sigma values.

**Arguments**:

- `log_snr` _Tensor_ - The input log SNR tensor.

**Returns**:

  tuple[Tensor, Tensor]: A tuple containing the alpha and sigma values, each obtained as a square root.
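
The relationship between log SNR and the two coefficients fits in a few lines. The notes for `calculate_alpha_log_snr` state that alpha is the square root of the sigmoid of the log SNR. The sketch additionally assumes a variance-preserving parameterization in which sigma is the square root of the sigmoid of the negated log SNR. Under that assumption $\alpha^2 + \sigma^2 = 1$ and $(\alpha / \sigma)^2 = \exp(\text{log\_snr})$. Plain floats stand in for tensors, and the helper names are hypothetical.

```python
import math
from typing import Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def log_snr_to_alpha_sigma(log_snr: float) -> Tuple[float, float]:
    """Sketch: alpha = sqrt(sigmoid(log_snr)); sigma = sqrt(sigmoid(-log_snr)) (assumed)."""
    alpha = math.sqrt(sigmoid(log_snr))
    sigma = math.sqrt(sigmoid(-log_snr))
    return alpha, sigma


# log SNR = 0 means equal signal and noise: both coefficients are sqrt(0.5).
alpha, sigma = log_snr_to_alpha_sigma(0.0)
```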

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformderivative"></a>

#### derivative

```python
def derivative(t: Tensor, func: Callable) -> Tensor
```

Compute the derivative of a function. Batched single-variable inputs are supported.

**Arguments**:

- `t` _Tensor_ - The time values at which the derivative is evaluated.
- `func` _Callable_ - The function to differentiate.

**Returns**:

- `Tensor` - The derivative, detached from the computational graph.
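
Conceptually, this evaluates d func / dt element-wise over a batch of scalar times. The dependency-free sketch below uses a central finite difference in place of whatever differentiation method the library uses (for example, autograd). It applies the difference to each entry of the batch independently. `batched_derivative` and the step size `h` are hypothetical.

```python
import math
from typing import Callable, List


def batched_derivative(ts: List[float], func: Callable[[float], float],
                       h: float = 1e-5) -> List[float]:
    """Central finite-difference estimate of func'(t) for each t in a batch."""
    return [(func(t + h) - func(t - h)) / (2.0 * h) for t in ts]


# d/dt sin(t) = cos(t)
grads = batched_derivative([0.0, 0.5, 1.0], math.sin)
```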

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_general_sde_terms"></a>

#### calculate\_general\_sde\_terms

```python
def calculate_general_sde_terms(t)
```

Compute the general SDE terms for a given time step t.

**Arguments**:

- `t` _Tensor_ - The input tensor representing the time step.

**Returns**:

  tuple[Tensor, Tensor]: A tuple containing the drift term `f_t` and the diffusion term `g_t_2`.

**Notes**:

  This method computes the drift and diffusion terms of the general SDE, which can be used to simulate the stochastic process. The drift term represents the deterministic part of the process, while the diffusion term represents the stochastic part.

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_beta"></a>

#### calculate\_beta

```python
def calculate_beta(t)
```

Compute the drift coefficient $\beta(t)$ for the Ornstein-Uhlenbeck (OU) process of the form $dx = -\frac{1}{2} \beta(t)\, x\, dt + \sqrt{\beta(t)}\, dw_t$, where

$$\beta(t) = \frac{d}{dt} \log \alpha(t)^2 = \frac{2}{\alpha(t)} \frac{d\alpha(t)}{dt}.$$

**Arguments**:

- `t` _Union[float, Tensor]_ - Time value(s) in [0, 1].

**Returns**:

- `Tensor` - $\beta(t)$.
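
The identity for $\beta(t)$ can be checked numerically. The sketch below works with plain floats and computes $\beta(t)$ in two ways: as a finite-difference derivative of $\log \alpha(t)^2$, and via the closed form $2\,\alpha'(t)/\alpha(t)$. The example schedule $\alpha(t) = \cos(\pi t / 2)$ and all helper names are illustrative assumptions, not the library's schedule.

```python
import math
from typing import Callable


def beta_numeric(alpha: Callable[[float], float], t: float, h: float = 1e-6) -> float:
    """beta(t) = d/dt log(alpha(t)**2), estimated with a central difference."""
    def log_alpha_sq(s: float) -> float:
        return math.log(alpha(s) ** 2)

    return (log_alpha_sq(t + h) - log_alpha_sq(t - h)) / (2.0 * h)


def beta_closed_form(alpha: Callable[[float], float],
                     d_alpha: Callable[[float], float], t: float) -> float:
    """beta(t) = 2 * (1 / alpha(t)) * d/dt alpha(t)."""
    return 2.0 * d_alpha(t) / alpha(t)


def example_alpha(t: float) -> float:
    """Illustrative schedule (assumption): alpha(t) = cos(pi * t / 2)."""
    return math.cos(math.pi * t / 2)


def example_d_alpha(t: float) -> float:
    """Analytic derivative of example_alpha."""
    return -math.pi / 2 * math.sin(math.pi * t / 2)


t = 0.3
numeric = beta_numeric(example_alpha, t)
exact = beta_closed_form(example_alpha, example_d_alpha, t)
```

For this example schedule, $\beta(t) = -\pi \tan(\pi t / 2)$. It is negative because alpha decreases as t increases, so the sign seen in practice depends on the schedule's time direction.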

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_alpha_log_snr"></a>

#### calculate\_alpha\_log\_snr

```python
def calculate_alpha_log_snr(log_snr: Tensor) -> Tensor
```

Compute alpha values based on the log SNR.

**Arguments**:

- `log_snr` _Tensor_ - The input tensor representing the log signal-to-noise ratio.

**Returns**:

- `Tensor` - A tensor representing the alpha values for the given log SNR.

**Notes**:

  This method computes alpha values as the square root of the sigmoid of the log SNR.

<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_alpha_t"></a>

#### calculate\_alpha\_t

```python
def calculate_alpha_t(t: Tensor) -> Tensor
```

Compute alpha values based on the log SNR schedule.

**Arguments**:

- `t` _Tensor_ - The input tensor representing the time steps.

**Returns**:

- `Tensor` - A tensor representing the alpha values for the given time steps.

**Notes**:

  This method computes alpha values as the square root of the sigmoid of the log SNR.

<a id="mocoschedulesnoisecontinuous_snr_transformsCosineSNRTransform"></a>

## CosineSNRTransform Objects

```python
class CosineSNRTransform(ContinuousSNRTransform)
```

A cosine SNR schedule.

**Arguments**:

- `nu` _Float_ - Hyperparameter for the cosine schedule exponent (default is 1.0).
- `s` _Float_ - Hyperparameter for the cosine schedule shift (default is 0.008).

<a id="mocoschedulesnoisecontinuous_snr_transformsCosineSNRTransform__init__"></a>

#### \_\_init\_\_

```python
def __init__(nu: Float = 1.0, s: Float = 0.008)
```

Initialize the CosineSNRTransform.

<a id="mocoschedulesnoisecontinuous_snr_transformsLinearSNRTransform"></a>

## LinearSNRTransform Objects

```python
class LinearSNRTransform(ContinuousSNRTransform)
```

A linear SNR schedule.

<a id="mocoschedulesnoisecontinuous_snr_transformsLinearSNRTransform__init__"></a>

#### \_\_init\_\_

```python
def __init__(min_value: Float = 1.0e-4)
```

Initialize the LinearSNRTransform.

**Arguments**:

- `min_value` _Float_ - Minimum SNR value. Defaults to 1.0e-4.

<a id="mocoschedulesnoisecontinuous_snr_transformsLinearLogInterpolatedSNRTransform"></a>

## LinearLogInterpolatedSNRTransform Objects

```python
class LinearLogInterpolatedSNRTransform(ContinuousSNRTransform)
```

An SNR schedule that interpolates linearly in log space.

<a id="mocoschedulesnoisecontinuous_snr_transformsLinearLogInterpolatedSNRTransform__init__"></a>

#### \_\_init\_\_

```python
def __init__(min_value: Float = -7.0, max_value=13.5)
```

Initialize the linear log-space interpolated SNR schedule from Chroma.

**Arguments**:

- `min_value` _Float_ - The minimum log SNR value. Defaults to -7.0.
- `max_value` _Float_ - The maximum log SNR value. Defaults to 13.5.
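
Linear interpolation in log space means the log SNR itself moves along a straight line between the two bounds. Because SNR = exp(log SNR), this amounts to a geometric interpolation of the SNR. The sketch below assumes that the log SNR goes from `max_value` at t = 0 to `min_value` at t = 1, which is a data-to-noise orientation. The actual orientation depends on the transform's `TimeDirection`, and `linear_log_snr` is a hypothetical helper.

```python
import math


def linear_log_snr(t: float, min_value: float = -7.0, max_value: float = 13.5) -> float:
    """Hypothetical: log SNR interpolated linearly from max_value (t=0) to min_value (t=1)."""
    return max_value + (min_value - max_value) * t


for t in (0.0, 0.5, 1.0):
    log_snr = linear_log_snr(t)
    print(f"t={t:.1f}  log_snr={log_snr:+.2f}  snr={math.exp(log_snr):.3g}")
```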

<a id="mocoschedulesnoisediscrete_noise_schedules"></a>

# bionemo.moco.schedules.noise.discrete\_noise\_schedules

<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteNoiseSchedule"></a>

## DiscreteNoiseSchedule Objects

```python
class DiscreteNoiseSchedule(ABC)
```

A base class for discrete noise schedules.

<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteNoiseSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int, direction: TimeDirection)
```

Initialize the DiscreteNoiseSchedule.

**Arguments**:

- `nsteps` _int_ - Number of discrete steps.
- `direction` _TimeDirection_ - Required. The time direction in which the schedule was built.

<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteNoiseSchedulegenerate_schedule"></a>

#### generate\_schedule

```python
def generate_schedule(nsteps: Optional[int] = None,
                      device: Union[str, torch.device] = "cpu",
                      synchronize: Optional[TimeDirection] = None) -> Tensor
```

Generate the noise schedule as a tensor.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Union[str, torch.device]_ - Device to place the schedule on (default is "cpu").
- `synchronize` _Optional[TimeDirection]_ - TimeDirection to synchronize the schedule with. If the schedule was defined with a different direction, this parameter flips it to match the specified one (default is None).

<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteNoiseSchedulecalculate_derivative"></a>

#### calculate\_derivative

```python
def calculate_derivative(
        nsteps: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        synchronize: Optional[TimeDirection] = None) -> Tensor
```

Calculate the time derivative of the schedule.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Union[str, torch.device]_ - Device to place the schedule on (default is "cpu").
- `synchronize` _Optional[TimeDirection]_ - TimeDirection to synchronize the schedule with. If the schedule was defined with a different direction, this parameter flips it to match the specified one (default is None).

**Returns**:

- `Tensor` - A tensor representing the time derivative of the schedule.

**Raises**:

- `NotImplementedError` - If the derivative calculation is not implemented for this schedule.

<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteCosineNoiseSchedule"></a>

## DiscreteCosineNoiseSchedule Objects

```python
class DiscreteCosineNoiseSchedule(DiscreteNoiseSchedule)
```

A cosine discrete noise schedule.

<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteCosineNoiseSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int, nu: Float = 1.0, s: Float = 0.008)
```

Initialize the DiscreteCosineNoiseSchedule.

**Arguments**:

- `nsteps` _int_ - Number of discrete steps.
- `nu` _Float_ - Hyperparameter for the cosine schedule exponent (default is 1.0).
- `s` _Float_ - Hyperparameter for the cosine schedule shift (default is 0.008).
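
For reference, the widely used cosine schedule of Nichol and Dhariwal (Improved DDPM) defines the cumulative signal level as $\bar\alpha(t) = f(t)/f(0)$. Here $f(t) = \cos^2\left(\frac{t/T + s}{1 + s} \cdot \frac{\pi}{2}\right)$, and `s` is the small shift documented above. The sketch below implements that formula for the default `nu=1.0` case only, because this page does not specify how `nu` enters. It is an illustration under that assumption, and `cosine_alpha_bar` is a hypothetical name.

```python
import math
from typing import List


def cosine_alpha_bar(nsteps: int, s: float = 0.008) -> List[float]:
    """alpha_bar(t) = f(t) / f(0) for t = 0..nsteps, with the cosine f(t) above."""
    def f(t: int) -> float:
        return math.cos((t / nsteps + s) / (1 + s) * math.pi / 2) ** 2

    f0 = f(0)
    return [f(t) / f0 for t in range(nsteps + 1)]


alpha_bar = cosine_alpha_bar(1000)
# Per-step betas follow from consecutive ratios, clipped at 0.999 as in the paper.
betas = [min(1 - alpha_bar[t] / alpha_bar[t - 1], 0.999) for t in range(1, len(alpha_bar))]
```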

<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteLinearNoiseSchedule"></a>

## DiscreteLinearNoiseSchedule Objects

```python
class DiscreteLinearNoiseSchedule(DiscreteNoiseSchedule)
```

A linear discrete noise schedule.

<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteLinearNoiseSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int, beta_start: Float = 1e-4, beta_end: Float = 0.02)
```

Initialize the DiscreteLinearNoiseSchedule.

**Arguments**:

- `nsteps` _int_ - Number of discrete time steps.
- `beta_start` _Float_ - Starting beta value. Defaults to 1e-4.
- `beta_end` _Float_ - Ending beta value. Defaults to 0.02.
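
A linear discrete schedule is usually understood as in DDPM (Ho et al., 2020). The per-step betas are evenly spaced between `beta_start` and `beta_end`, and $\bar\alpha_t = \prod_{i \le t} (1 - \beta_i)$. The sketch below follows that convention. Two aspects are assumptions rather than documented behavior: that both endpoints are included, and the helper names.

```python
from typing import List


def linear_betas(nsteps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """Evenly spaced betas from beta_start to beta_end (both endpoints included)."""
    if nsteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (nsteps - 1)
    return [beta_start + i * step for i in range(nsteps)]


def cumulative_alphas(betas: List[float]) -> List[float]:
    """alpha_bar_t = prod_{i <= t} (1 - beta_i)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


betas = linear_betas(1000)
alpha_bar = cumulative_alphas(betas)
```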

<a id="mocoschedulesnoise"></a>

# bionemo.moco.schedules.noise

<a id="mocoschedulesnoisecontinuous_noise_transforms"></a>

# bionemo.moco.schedules.noise.continuous\_noise\_transforms

<a id="mocoschedulesnoisecontinuous_noise_transformsContinuousExpNoiseTransform"></a>

## ContinuousExpNoiseTransform Objects

```python
class ContinuousExpNoiseTransform(ABC)
```

A base class for continuous schedules.

$\alpha = \exp(-\sigma)$, where $1 - \alpha$ controls the masking fraction.

<a id="mocoschedulesnoisecontinuous_noise_transformsContinuousExpNoiseTransform__init__"></a>

#### \_\_init\_\_

```python
def __init__(direction: TimeDirection)
```

Initialize the ContinuousExpNoiseTransform.

**Arguments**:

- `direction` _TimeDirection_ - Required. The time direction in which the schedule was built.

<a id="mocoschedulesnoisecontinuous_noise_transformsContinuousExpNoiseTransformcalculate_sigma"></a>

#### calculate\_sigma

```python
def calculate_sigma(t: Tensor,
                    device: Union[str, torch.device] = "cpu",
                    synchronize: Optional[TimeDirection] = None) -> Tensor
```

Calculate the sigma for the given time steps.

**Arguments**:

- `t` _Tensor_ - The input tensor representing the time steps, with values ranging from 0 to 1.
- `device` _Union[str, torch.device]_ - The device to place the schedule on. Defaults to "cpu".
- `synchronize` _Optional[TimeDirection]_ - TimeDirection to synchronize the schedule with. If the schedule was defined with a different direction, this parameter flips it to match the specified one. Defaults to None.

**Returns**:

- `Tensor` - A tensor representing the sigma values for the given time steps.

**Raises**:

- `ValueError` - If the input time steps exceed the maximum allowed value of 1.

<a id="mocoschedulesnoisecontinuous_noise_transformsContinuousExpNoiseTransformsigma_to_alpha"></a>

#### sigma\_to\_alpha

```python
def sigma_to_alpha(sigma: Tensor) -> Tensor
```

Converts sigma to alpha values via $\alpha = \exp(-\sigma)$.

**Arguments**:

- `sigma` _Tensor_ - The input sigma tensor.

**Returns**:

- `Tensor` - A tensor containing the alpha values.
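
Here is a small numeric illustration of the masking interpretation stated for this class. With $\alpha = \exp(-\sigma)$, the quantity $1 - \alpha$ is the fraction of tokens expected to be masked. Plain floats stand in for tensors, and `masking_fraction` is a hypothetical helper.

```python
import math


def sigma_to_alpha(sigma: float) -> float:
    """alpha = exp(-sigma)."""
    return math.exp(-sigma)


def masking_fraction(sigma: float) -> float:
    """Expected masked fraction 1 - alpha for a given sigma."""
    return 1.0 - sigma_to_alpha(sigma)


for sigma in (0.0, math.log(2.0), 5.0):
    print(f"sigma={sigma:.3f}  alpha={sigma_to_alpha(sigma):.3f}  masked={masking_fraction(sigma):.3f}")
```

At sigma = 0 nothing is masked, at sigma = log 2 half of the tokens are, and large sigma approaches full masking.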

<a id="mocoschedulesnoisecontinuous_noise_transformsCosineExpNoiseTransform"></a>

## CosineExpNoiseTransform Objects

```python
class CosineExpNoiseTransform(ContinuousExpNoiseTransform)
```

A cosine exponential noise schedule.

<a id="mocoschedulesnoisecontinuous_noise_transformsCosineExpNoiseTransform__init__"></a>

#### \_\_init\_\_

```python
def __init__(eps: Float = 1.0e-3)
```

Initialize the CosineExpNoiseTransform.

**Arguments**:

- `eps` _Float_ - Small value to prevent numerical issues.

<a id="mocoschedulesnoisecontinuous_noise_transformsCosineExpNoiseTransformd_dt_sigma"></a>

#### d\_dt\_sigma

```python
def d_dt_sigma(t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor
```

Compute the derivative of sigma with respect to time.

**Arguments**:

- `t` _Tensor_ - The input tensor representing the time steps.
- `device` _Union[str, torch.device]_ - The device to place the schedule on. Defaults to "cpu".

**Returns**:

- `Tensor` - A tensor representing the derivative of sigma with respect to time.

**Notes**:

The derivative of sigma as a function of time is given by:

$$\frac{d}{dt}\sigma(t) = \frac{d}{dt}\left(-\log\left(\cos\left(\frac{\pi t}{2}\right) + \epsilon\right)\right)$$

Using the chain rule, we get:

$$\frac{d}{dt}\sigma(t) = \frac{-1}{\cos\left(\frac{\pi t}{2}\right) + \epsilon} \cdot \left(-\sin\left(\frac{\pi t}{2}\right) \cdot \frac{\pi}{2}\right)$$

This is the derivative that is computed and returned by this method.
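
The derivative in the notes can be checked against a finite difference. The sketch below implements $\sigma(t) = -\log(\cos(\pi t/2) + \epsilon)$ and its closed-form derivative as plain-float functions. In the spirit of `calculate_sigma`, it raises `ValueError` for $t > 1$. The function names are hypothetical.

```python
import math


def cosine_sigma(t: float, eps: float = 1.0e-3) -> float:
    """sigma(t) = -log(cos(pi * t / 2) + eps)."""
    if t > 1.0:
        raise ValueError("t must not exceed 1")
    return -math.log(math.cos(t * math.pi / 2) + eps)


def cosine_d_dt_sigma(t: float, eps: float = 1.0e-3) -> float:
    """Closed form: (pi / 2) * sin(pi * t / 2) / (cos(pi * t / 2) + eps)."""
    return (math.pi / 2) * math.sin(t * math.pi / 2) / (math.cos(t * math.pi / 2) + eps)


t, h = 0.4, 1e-6
finite_diff = (cosine_sigma(t + h) - cosine_sigma(t - h)) / (2 * h)
```

The `eps` term keeps $\sigma(1)$ finite (about $\log(1/\epsilon)$) where $\cos(\pi/2)$ would otherwise send the log to infinity.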

<a id="mocoschedulesnoisecontinuous_noise_transformsLogLinearExpNoiseTransform"></a>

## LogLinearExpNoiseTransform Objects

```python
class LogLinearExpNoiseTransform(ContinuousExpNoiseTransform)
```

A log-linear exponential schedule.

<a id="mocoschedulesnoisecontinuous_noise_transformsLogLinearExpNoiseTransform__init__"></a>

#### \_\_init\_\_

```python
def __init__(eps: Float = 1.0e-3)
```

Initialize the LogLinearExpNoiseTransform.

**Arguments**:

- `eps` _Float_ - Small value to prevent numerical issues.

<a id="mocoschedulesnoisecontinuous_noise_transformsLogLinearExpNoiseTransformd_dt_sigma"></a>

#### d\_dt\_sigma

```python
def d_dt_sigma(t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor
```

Compute the derivative of sigma with respect to time.

**Arguments**:

- `t` _Tensor_ - The input tensor representing the time steps.
- `device` _Union[str, torch.device]_ - The device to place the schedule on. Defaults to "cpu".

**Returns**:

- `Tensor` - A tensor representing the derivative of sigma with respect to time.
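
The formula for this schedule is not spelled out here. As an assumption, the sketch below uses the log-linear form common in masked diffusion language models, $\sigma(t) = -\log(1 - (1 - \epsilon)\,t)$. Its masking fraction $1 - \alpha = (1 - \epsilon)\,t$ grows linearly in $t$, and its derivative is $\frac{d\sigma}{dt} = \frac{1 - \epsilon}{1 - (1 - \epsilon)\,t}$. The function names are hypothetical.

```python
import math


def log_linear_sigma(t: float, eps: float = 1.0e-3) -> float:
    """Assumed form: sigma(t) = -log(1 - (1 - eps) * t)."""
    return -math.log(1.0 - (1.0 - eps) * t)


def log_linear_d_dt_sigma(t: float, eps: float = 1.0e-3) -> float:
    """Derivative of the assumed form: (1 - eps) / (1 - (1 - eps) * t)."""
    return (1.0 - eps) / (1.0 - (1.0 - eps) * t)


# The masking fraction 1 - exp(-sigma(t)) equals (1 - eps) * t, i.e. it is linear in t.
t = 0.25
masked = 1.0 - math.exp(-log_linear_sigma(t))
```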
1500 |
|
|
|
1501 |
# bionemo.moco.schedules |
|
|
1502 |
|
|
|
1503 |
<a id="mocoschedulesutils"></a> |
|
|
1504 |
|
|
|
1505 |
# bionemo.moco.schedules.utils |
|
|
1506 |
|
|
|
1507 |
<a id="mocoschedulesutilsTimeDirection"></a> |
|
|
1508 |
|
|
|
1509 |
## TimeDirection Objects |
|
|
1510 |
|
|
|
1511 |
```python |
|
|
1512 |
class TimeDirection(Enum) |
|
|
1513 |
``` |
|
|
1514 |
|
|
|
1515 |
Enum for the direction of the noise schedule. |
|
|
1516 |
|
|
|
1517 |
<a id="mocoschedulesutilsTimeDirectionUNIFIED"></a> |
|
|
1518 |
|
|
|
1519 |
#### UNIFIED |
|
|
1520 |
|
|
|
1521 |
Noise(0) --> Data(1) |
|
|
1522 |
|
|
|
1523 |
<a id="mocoschedulesutilsTimeDirectionDIFFUSION"></a> |
|
|
1524 |
|
|
|
1525 |
#### DIFFUSION |
|
|
1526 |
|
|
|
1527 |
Noise(1) --> Data(0) |
|
|
1528 |
|
|
|
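
The two directions place noise and data at opposite ends of [0, 1]. Converting a time from one convention to the other therefore amounts to $t \mapsto 1 - t$. The sketch below mirrors the two members in a standalone `Enum` with illustrative values. It also shows a hypothetical `synchronize_time` helper, of the kind the `synchronize` arguments elsewhere in this document describe; it is an illustration, not the library's implementation.

```python
from enum import Enum


class TimeDirection(Enum):
    """Standalone mirror of the two documented directions (values are illustrative)."""
    UNIFIED = "unified"      # Noise(0) --> Data(1)
    DIFFUSION = "diffusion"  # Noise(1) --> Data(0)


def synchronize_time(t: float, source: TimeDirection, target: TimeDirection) -> float:
    """Hypothetical helper: re-express time t from the source direction in the target direction."""
    return t if source == target else 1.0 - t


# t = 0.2 is close to noise under UNIFIED; the same state sits near t = 0.8 under DIFFUSION.
flipped = synchronize_time(0.2, TimeDirection.UNIFIED, TimeDirection.DIFFUSION)
```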

<a id="mocoschedulesinference_time_schedules"></a>

# bionemo.moco.schedules.inference\_time\_schedules

<a id="mocoschedulesinference_time_schedulesInferenceSchedule"></a>

## InferenceSchedule Objects

```python
class InferenceSchedule(ABC)
```

A base class for inference time schedules.

<a id="mocoschedulesinference_time_schedulesInferenceSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")
```

Initialize the InferenceSchedule.

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `min_t` _Float_ - Minimum time value. Defaults to 0.
- `padding` _Float_ - Padding time value. Defaults to 0.
- `dilation` _Float_ - Dilation time value, i.e. the number of replicates. Defaults to 0.
- `direction` _Union[TimeDirection, str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter flips it to match the specified one (default is `TimeDirection.UNIFIED`).
- `device` _Union[str, torch.device]_ - Device to place the schedule on (default is "cpu").

<a id="mocoschedulesinference_time_schedulesInferenceSchedulegenerate_schedule"></a>

#### generate\_schedule

```python
@abstractmethod
def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Generate the time schedule as a tensor.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[Union[str, torch.device]]_ - Device to place the schedule on (default is None).

<a id="mocoschedulesinference_time_schedulesInferenceSchedulepad_time"></a>

#### pad\_time

```python
def pad_time(n_samples: int,
             scalar_time: Float,
             device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Creates a tensor of shape (n_samples,) filled with a scalar time value.

**Arguments**:

- `n_samples` _int_ - The desired dimension of the output tensor.
- `scalar_time` _Float_ - The scalar time value to fill the tensor with.
- `device` _Optional[Union[str, torch.device]]_ - The device to place the tensor on. Defaults to None, which uses the default device.

**Returns**:

- `Tensor` - A tensor of shape (n_samples,) filled with the scalar time value.
1609 |
<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedule"></a>

## ContinuousInferenceSchedule Objects

```python
class ContinuousInferenceSchedule(InferenceSchedule)
```

A base class for continuous time inference schedules.

<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int,
             inclusive_end: bool = False,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")
```

Initialize the ContinuousInferenceSchedule.

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule; otherwise the schedule ends at 1.0 - 1/nsteps (default is False).
- `min_t` _Float_ - Minimum time value; defaults to 0.
- `padding` _Float_ - Padding time value; defaults to 0.
- `dilation` _Float_ - Dilation time value (i.e., the number of replicates); defaults to 0.
- `direction` _Union[TimeDirection, str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter flips it to match the specified one (default is TimeDirection.UNIFIED).
- `device` _Union[str, torch.device]_ - Device to place the schedule on (default is "cpu").

<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedulediscretize"></a>

#### discretize

```python
def discretize(nsteps: Optional[int] = None,
               schedule: Optional[Tensor] = None,
               device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Discretize the time schedule into a tensor of time deltas.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `schedule` _Optional[Tensor]_ - Time schedule. If None, it is generated with `generate_schedule`.
- `device` _Optional[Union[str, torch.device]]_ - Device to place the schedule on. Defaults to None.

**Returns**:

- `Tensor` - A tensor of time deltas.

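The following plain-Python sketch roughly illustrates what `discretize` computes: it turns a list of time points into per-step deltas. It covers only the non-inclusive case. In that case the schedule stops short of 1.0, and the end point 1.0 is appended so that each time point gets one delta. The helper name `discretize_deltas` is hypothetical. The library's tensor implementation, including how it treats `inclusive_end`, may differ.

```python
def discretize_deltas(schedule, end=1.0):
    """Return one time delta per time point (hypothetical helper, not the moco API).

    Assumes `schedule` is increasing and stops short of `end`, as a
    non-inclusive continuous schedule does; `end` is appended so the
    last step closes the interval.
    """
    points = list(schedule) + [end]
    # Difference between each time point and the next one.
    return [b - a for a, b in zip(points, points[1:])]


print(discretize_deltas([0.0, 0.25, 0.5, 0.75]))  # [0.25, 0.25, 0.25, 0.25]
```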
<a id="mocoschedulesinference_time_schedulesDiscreteInferenceSchedule"></a>

## DiscreteInferenceSchedule Objects

```python
class DiscreteInferenceSchedule(InferenceSchedule)
```

A base class for discrete time inference schedules.

<a id="mocoschedulesinference_time_schedulesDiscreteInferenceSchedulediscretize"></a>

#### discretize

```python
def discretize(nsteps: Optional[int] = None,
               device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Discretize the time schedule into a tensor of time deltas.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[Union[str, torch.device]]_ - Device to place the schedule on. Defaults to None.

**Returns**:

- `Tensor` - A tensor of time deltas.

<a id="mocoschedulesinference_time_schedulesDiscreteLinearInferenceSchedule"></a>

## DiscreteLinearInferenceSchedule Objects

```python
class DiscreteLinearInferenceSchedule(DiscreteInferenceSchedule)
```

A linear time schedule for discrete time inference.

<a id="mocoschedulesinference_time_schedulesDiscreteLinearInferenceSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")
```

Initialize the DiscreteLinearInferenceSchedule.

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `min_t` _Float_ - Minimum time value; defaults to 0.
- `padding` _Float_ - Padding time value; defaults to 0.
- `dilation` _Float_ - Dilation time value (i.e., the number of replicates); defaults to 0.
- `direction` _Union[TimeDirection, str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter flips it to match the specified one (default is TimeDirection.UNIFIED).
- `device` _Union[str, torch.device]_ - Device to place the schedule on (default is "cpu").

<a id="mocoschedulesinference_time_schedulesDiscreteLinearInferenceSchedulegenerate_schedule"></a>

#### generate\_schedule

```python
def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Generate the linear time schedule as a tensor.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[Union[str, torch.device]]_ - Device to place the schedule on. Defaults to None.

**Returns**:

- `Tensor` - A tensor of time steps.

<a id="mocoschedulesinference_time_schedulesLinearInferenceSchedule"></a>

## LinearInferenceSchedule Objects

```python
class LinearInferenceSchedule(ContinuousInferenceSchedule)
```

A linear time schedule for continuous time inference.

<a id="mocoschedulesinference_time_schedulesLinearInferenceSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int,
             inclusive_end: bool = False,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")
```

Initialize the LinearInferenceSchedule.

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule; otherwise the schedule ends at 1.0 - 1/nsteps (default is False).
- `min_t` _Float_ - Minimum time value; defaults to 0.
- `padding` _Float_ - Padding time value; defaults to 0.
- `dilation` _Float_ - Dilation time value (i.e., the number of replicates); defaults to 0.
- `direction` _Union[TimeDirection, str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter flips it to match the specified one (default is TimeDirection.UNIFIED).
- `device` _Union[str, torch.device]_ - Device to place the schedule on (default is "cpu").

<a id="mocoschedulesinference_time_schedulesLinearInferenceSchedulegenerate_schedule"></a>

#### generate\_schedule

```python
def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Generate the linear time schedule as a tensor.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[Union[str, torch.device]]_ - Device to place the schedule on. Defaults to None.

**Returns**:

- `Tensor` - A tensor of time steps.

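The sketch below shows one plausible way to build a linear continuous schedule in plain Python. It assumes `nsteps + 1` evenly spaced points on `[min_t, 1.0]`, with the final 1.0 dropped unless `inclusive_end` is set. With `min_t=0`, it ends at `1.0 - 1/nsteps`, which matches the `inclusive_end` description for this class. The function name is hypothetical, and `padding`, `dilation`, and `direction` are not modeled.

```python
def linear_schedule(nsteps, min_t=0.0, inclusive_end=False):
    """Evenly spaced time points on [min_t, 1.0] (hypothetical sketch, not the moco API)."""
    step = (1.0 - min_t) / nsteps
    points = [min_t + i * step for i in range(nsteps + 1)]
    # Drop the end point 1.0 unless the caller asked for it.
    return points if inclusive_end else points[:-1]


print(linear_schedule(4))  # [0.0, 0.25, 0.5, 0.75]
```

This page does not say whether the library's inclusive variant returns `nsteps` or `nsteps + 1` points. The sketch returns `nsteps + 1`.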
<a id="mocoschedulesinference_time_schedulesPowerInferenceSchedule"></a>

## PowerInferenceSchedule Objects

```python
class PowerInferenceSchedule(ContinuousInferenceSchedule)
```

A power time schedule for inference, where time steps are generated by raising a uniform schedule to a specified power.

<a id="mocoschedulesinference_time_schedulesPowerInferenceSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int,
             inclusive_end: bool = False,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             exponent: Float = 1.0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")
```

Initialize the PowerInferenceSchedule.

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule; otherwise the schedule ends below 1.0 (default is False).
- `min_t` _Float_ - Minimum time value; defaults to 0.
- `padding` _Float_ - Padding time value; defaults to 0.
- `dilation` _Float_ - Dilation time value (i.e., the number of replicates); defaults to 0.
- `exponent` _Float_ - Power parameter; defaults to 1.0.
- `direction` _Union[TimeDirection, str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter flips it to match the specified one (default is TimeDirection.UNIFIED).
- `device` _Union[str, torch.device]_ - Device to place the schedule on (default is "cpu").

<a id="mocoschedulesinference_time_schedulesPowerInferenceSchedulegenerate_schedule"></a>

#### generate\_schedule

```python
def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Generate the power time schedule as a tensor.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[Union[str, torch.device]]_ - Device to place the schedule on. Defaults to None.

**Returns**:

- `Tensor` - A tensor of time steps.

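The class description says this schedule is a uniform schedule raised to a power. A minimal sketch of that idea follows, with these assumptions:

- The uniform base is `i / nsteps` for `i = 0..nsteps`.
- The end point is dropped unless `inclusive_end` is set.
- `min_t`, `padding`, `dilation`, and `direction` are ignored.

The function name is hypothetical. Exponents above 1 cluster the points near 0, and exponents below 1 cluster them near 1.

```python
def power_schedule(nsteps, exponent=1.0, inclusive_end=False):
    """Uniform time points raised to `exponent` (hypothetical sketch, not the moco API)."""
    uniform = [i / nsteps for i in range(nsteps + 1)]
    powered = [u ** exponent for u in uniform]
    # With exponent=1.0 this reduces to the plain linear schedule.
    return powered if inclusive_end else powered[:-1]


print(power_schedule(4, exponent=2.0))  # [0.0, 0.0625, 0.25, 0.5625]
```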
<a id="mocoschedulesinference_time_schedulesLogInferenceSchedule"></a>

## LogInferenceSchedule Objects

```python
class LogInferenceSchedule(ContinuousInferenceSchedule)
```

A log time schedule for inference, where time steps are generated by taking the logarithm of a uniform schedule.

<a id="mocoschedulesinference_time_schedulesLogInferenceSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int,
             inclusive_end: bool = False,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             exponent: Float = -2.0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")
```

Initialize the LogInferenceSchedule.

The resulting schedule is log-spaced. For 100 steps with default parameters, it is:

    tensor([0.0000, 0.0455, 0.0889, 0.1303, 0.1699, 0.2077, 0.2439, 0.2783, 0.3113,
            0.3427, 0.3728, 0.4015, 0.4288, 0.4550, 0.4800, 0.5039, 0.5266, 0.5484,
            0.5692, 0.5890, 0.6080, 0.6261, 0.6434, 0.6599, 0.6756, 0.6907, 0.7051,
            0.7188, 0.7319, 0.7444, 0.7564, 0.7678, 0.7787, 0.7891, 0.7991, 0.8086,
            0.8176, 0.8263, 0.8346, 0.8425, 0.8500, 0.8572, 0.8641, 0.8707, 0.8769,
            0.8829, 0.8887, 0.8941, 0.8993, 0.9043, 0.9091, 0.9136, 0.9180, 0.9221,
            0.9261, 0.9299, 0.9335, 0.9369, 0.9402, 0.9434, 0.9464, 0.9492, 0.9520,
            0.9546, 0.9571, 0.9595, 0.9618, 0.9639, 0.9660, 0.9680, 0.9699, 0.9717,
            0.9734, 0.9751, 0.9767, 0.9782, 0.9796, 0.9810, 0.9823, 0.9835, 0.9847,
            0.9859, 0.9870, 0.9880, 0.9890, 0.9899, 0.9909, 0.9917, 0.9925, 0.9933,
            0.9941, 0.9948, 0.9955, 0.9962, 0.9968, 0.9974, 0.9980, 0.9985, 0.9990,
            0.9995])

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule; otherwise the schedule ends below 1.0 (default is False).
- `min_t` _Float_ - Minimum time value; defaults to 0.
- `padding` _Float_ - Padding time value; defaults to 0.
- `dilation` _Float_ - Dilation time value (i.e., the number of replicates); defaults to 0.
- `exponent` _Float_ - Log-space exponent; defaults to -2.0. The lower (more negative) the exponent, the faster the schedule moves from 0 to 0.9, leaving more steps between 0.9 and 1.0.
- `direction` _Union[TimeDirection, str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter flips it to match the specified one (default is TimeDirection.UNIFIED).
- `device` _Union[str, torch.device]_ - Device to place the schedule on (default is "cpu").

<a id="mocoschedulesinference_time_schedulesLogInferenceSchedulegenerate_schedule"></a>

#### generate\_schedule

```python
def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Generate the log time schedule as a tensor.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[Union[str, torch.device]]_ - Device to place the schedule on. Defaults to None.

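A closed-form, log-spaced rule reproduces the 100-step tensor listed under `LogInferenceSchedule.__init__`:

t_j = (1 - 10^(exponent * j / nsteps)) / (1 - 10^exponent), for j = 0..nsteps

The final 1.0 is dropped unless `inclusive_end` is set. This is a reconstruction that matches those printed values, not necessarily the library's code. It ignores `min_t`, `padding`, `dilation`, and `direction`, and the function name is hypothetical.

```python
def log_schedule(nsteps, exponent=-2.0, inclusive_end=False):
    """Log-spaced time points on [0, 1] (hypothetical reconstruction, not the moco API)."""
    scale = 1.0 - 10.0 ** exponent  # normalizer so the last point is exactly 1.0
    points = [(1.0 - 10.0 ** (exponent * j / nsteps)) / scale for j in range(nsteps + 1)]
    # Drop the end point 1.0 unless the caller asked for it.
    return points if inclusive_end else points[:-1]


t = log_schedule(100)
print(len(t), round(t[1], 4), round(t[50], 4), round(t[-1], 4))  # 100 0.0455 0.9091 0.9995
```

With `exponent=-2.0`, half of the steps (j = 50) already reach about 0.9091. This is the acceleration toward 0.9 that the `exponent` argument describes.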
<a id="mocointerpolantscontinuous_timediscrete"></a>

# bionemo.moco.interpolants.continuous\_time.discrete

<a id="mocointerpolantscontinuous_timediscretemdlm"></a>

# bionemo.moco.interpolants.continuous\_time.discrete.mdlm

<a id="mocointerpolantscontinuous_timediscretemdlmMDLM"></a>

## MDLM Objects

```python
class MDLM(Interpolant)
```

A Masked Diffusion Language Model (MDLM) interpolant.

**Examples**:

```python
import torch
from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
from bionemo.moco.interpolants.continuous_time.discrete.mdlm import MDLM
from bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule


mdlm = MDLM(
    time_distribution=UniformTimeDistribution(..., discrete_time=False),
    prior_distribution=DiscreteMaskedPrior(...),
    noise_schedule=CosineExpNoiseTransform(...),
)
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = mdlm.sample_time(batch_size)
    xt = mdlm.interpolate(data, time)

    logits = model(xt, time)
    loss = mdlm.loss(logits, data, xt, time)
    loss.backward()

# Generation
x_pred = mdlm.sample_prior(data.shape)
schedule = LinearInferenceSchedule(...)
inference_time = schedule.generate_schedule()
dts = schedule.discretize()
for t, dt in zip(inference_time, dts):
    time = torch.full((batch_size,), t)
    logits = model(x_pred, time)
    x_pred = mdlm.step(logits, time, x_pred, dt)
# x_pred now holds the generated sequences
```

<a id="mocointerpolantscontinuous_timediscretemdlmMDLM__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: DiscreteMaskedPrior,
             noise_schedule: ContinuousExpNoiseTransform,
             device: str = "cpu",
             rng_generator: Optional[torch.Generator] = None)
```

Initialize the Masked Diffusion Language Model (MDLM) interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution governing the time variable in the diffusion process.
- `prior_distribution` _DiscreteMaskedPrior_ - The prior distribution over the discrete token space, including masked tokens.
- `noise_schedule` _ContinuousExpNoiseTransform_ - The noise schedule defining the noise intensity as a function of time.
- `device` _str, optional_ - The device to use for computations. Defaults to "cpu".
- `rng_generator` _Optional[torch.Generator], optional_ - The random number generator for reproducibility. Defaults to None.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMinterpolate"></a>

#### interpolate

```python
def interpolate(data: Tensor, t: Tensor)
```

Get x(t) with given time t from noise and data.

**Arguments**:

- `data` _Tensor_ - Target discrete ids.
- `t` _Tensor_ - Time.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMforward_process"></a>

#### forward\_process

```python
def forward_process(data: Tensor, t: Tensor) -> Tensor
```

Apply the forward process to the data at time t.

**Arguments**:

- `data` _Tensor_ - Target discrete ids.
- `t` _Tensor_ - Time.

**Returns**:

- `Tensor` - x(t) after applying the forward process.

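The MDLM forward process is an absorbing-state corruption. Each token independently survives with some keep probability and is otherwise replaced by the mask token. The sketch below takes that keep probability as given. It does not show how moco derives it from `t` and the noise schedule, or which time direction moco uses. `mask_tokens` and `MASK_ID` are hypothetical names.

```python
import random

MASK_ID = 32  # hypothetical mask-token id for illustration


def mask_tokens(tokens, keep_prob, mask_index=MASK_ID, rng=None):
    """Absorbing-state corruption sketch: keep each token with probability keep_prob."""
    rng = rng or random.Random()
    # rng.random() is in [0, 1), so keep_prob=1.0 keeps everything and 0.0 masks everything.
    return [tok if rng.random() < keep_prob else mask_index for tok in tokens]


print(mask_tokens([5, 7, 9, 11], keep_prob=0.5, rng=random.Random(0)))
```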
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMloss"></a>

#### loss

```python
def loss(logits: Tensor,
         target: Tensor,
         xt: Tensor,
         time: Tensor,
         mask: Optional[Tensor] = None,
         use_weight=True)
```

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the batch x node target,
considering the current state of the discrete sequence `xt` at time `time`.

If `use_weight` is True, the loss is weighted by the reduced form of the MDLM time weight for the continuous NELBO,
as specified in Equation 11 of https://arxiv.org/pdf/2406.07524. This weight is proportional to the derivative
of the noise schedule with respect to time. It emphasizes accurate predictions at certain times
in the diffusion process.

**Arguments**:

- `logits` _Tensor_ - The predicted output from the model, with shape batch x node x class.
- `target` _Tensor_ - The target output for the model prediction, with shape batch x node.
- `xt` _Tensor_ - The current state of the discrete sequence, with shape batch x node.
- `time` _Tensor_ - The time at which the loss is calculated.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `use_weight` _bool, optional_ - Whether to use the MDLM time weight for the loss. Defaults to True.

**Returns**:

- `Tensor` - The calculated loss, as a batch tensor.

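The `loss` docstring cites Equation 11 of the MDLM paper, the continuous-time NELBO. In that paper's convention, t = 0 is clean data. α_t is the probability that a token is still unmasked at time t, and it decreases from 1 to 0. In the formula below, z_t is the noisy sequence, x is the clean one-hot data, and x_θ is the model's predicted distribution:

$$
\mathcal{L}^{\infty}_{\mathrm{NELBO}} = \mathbb{E}_{q}\int_{t=0}^{t=1} \frac{\alpha'_t}{1-\alpha_t}\,\log\langle \mathbf{x}_\theta(\mathbf{z}_t, t), \mathbf{x}\rangle \, dt
$$

Because α'_t ≤ 0 and the log-probability is non-positive, the integrand is non-negative. The factor α'_t/(1 - α_t) is the time weight that `use_weight` refers to. This only restates the paper; how moco maps its own time direction and noise schedule onto α_t is not shown here.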
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep"></a>

#### step

```python
def step(logits: Tensor,
         t: Tensor,
         xt: Tensor,
         dt: Tensor,
         temperature: float = 1.0) -> Tensor
```

Perform a single MDLM DDPM sampling step.

**Arguments**:

- `logits` _Tensor_ - The input logits.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current state.
- `dt` _Tensor_ - The time step increment.
- `temperature` _float_ - Softmax temperature; defaults to 1.0.

**Returns**:

- `Tensor` - The updated state.

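`step` is presumably based on the MDLM paper's ancestral sampler. Between two times with keep probabilities α_t < α_s, that sampler reveals each still-masked position with probability (α_s - α_t)/(1 - α_t), drawing the token from the model's distribution. Otherwise the position stays masked, and already revealed tokens are carried over unchanged.

The single-sequence sketch below assumes both keep probabilities are already known. It does not show how moco computes them from `t`, `dt`, and the noise schedule, and `mdlm_reverse_step` is a hypothetical name.

```python
import math
import random


def mdlm_reverse_step(logits, xt, alpha_t, alpha_s, mask_index, temperature=1.0, rng=None):
    """Ancestral MDLM step sketch (requires alpha_t < 1 and alpha_t <= alpha_s)."""
    rng = rng or random.Random()
    # Probability that a still-masked position gets revealed during this step.
    unmask_prob = (alpha_s - alpha_t) / (1.0 - alpha_t)
    out = []
    for row, tok in zip(logits, xt):
        if tok != mask_index or rng.random() >= unmask_prob:
            # Unmasked tokens are carried over; masked ones not revealed stay masked.
            out.append(tok)
            continue
        # Sample from softmax(row / temperature) with zero mass on the mask token.
        peak = max(row)
        weights = [0.0 if i == mask_index else math.exp((v - peak) / temperature)
                   for i, v in enumerate(row)]
        out.append(rng.choices(range(len(row)), weights=weights)[0])
    return out


logits = [[0.0, 50.0, 0.0], [0.0, 50.0, 0.0]]
print(mdlm_reverse_step(logits, [2, 0], alpha_t=0.2, alpha_s=1.0, mask_index=2,
                        rng=random.Random(0)))  # [1, 0]
```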
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMget_num_steps_confidence"></a>

#### get\_num\_steps\_confidence

```python
def get_num_steps_confidence(xt: Tensor, num_tokens_unmask: int = 1)
```

Calculate the maximum number of steps with confidence.

This method computes the maximum count of occurrences where the input tensor `xt` matches the `mask_index`
along the last dimension (-1). The result is returned as a single float value.

**Arguments**:

- `xt` _Tensor_ - Input tensor to evaluate against the mask index.
- `num_tokens_unmask` _int_ - Number of tokens to unmask at each step.

**Returns**:

- `float` - The maximum number of steps with confidence (i.e., matching the mask index).

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_confidence"></a>

#### step\_confidence

```python
def step_confidence(logits: Tensor,
                    xt: Tensor,
                    curr_step: int,
                    num_steps: int,
                    logit_temperature: float = 1.0,
                    randomness: float = 1.0,
                    confidence_temperature: float = 1.0,
                    num_tokens_unmask: int = 1) -> Tensor
```

Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.

Method taken from GenMol (Lee et al., https://arxiv.org/abs/2501.06158).

**Arguments**:

- `logits` _Tensor_ - Predicted logits.
- `xt` _Tensor_ - Input sequence.
- `curr_step` _int_ - Current step.
- `num_steps` _int_ - Total number of steps.
- `logit_temperature` _float_ - Temperature for the softmax over logits; defaults to 1.0.
- `randomness` _float_ - Scale for the Gumbel noise; defaults to 1.0.
- `confidence_temperature` _float_ - Temperature for the Gumbel confidence; defaults to 1.0.
- `num_tokens_unmask` _int_ - Number of tokens to unmask at each step; defaults to 1.

**Returns**:

- `Tensor` - The updated input sequence `xt`, with `num_tokens_unmask` tokens unmasked per step.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_argmax"></a>

#### step\_argmax

```python
def step_argmax(model_out: Tensor)
```

Returns the index of the maximum value in the last dimension of the model output.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.

**Returns**:

- `Tensor` - The index of the maximum value in the last dimension of the model output.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMcalculate_score"></a>

#### calculate\_score

```python
def calculate_score(logits, x, t)
```

Return the score of the sample x at time t, given the corresponding model output logits.

**Arguments**:

- `logits` _Tensor_ - The output of the model.
- `x` _Tensor_ - The current data point.
- `t` _Tensor_ - The current time.

**Returns**:

- `Tensor` - The score defined in Appendix C.3, Equation 76 of MDLM.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_self_path_planning"></a>

#### step\_self\_path\_planning

```python
def step_self_path_planning(logits: Tensor,
                            xt: Tensor,
                            t: Tensor,
                            curr_step: int,
                            num_steps: int,
                            logit_temperature: float = 1.0,
                            randomness: float = 1.0,
                            confidence_temperature: float = 1.0,
                            score_type: Literal["confidence",
                                                "random"] = "confidence",
                            fix_mask: Optional[Tensor] = None) -> Tensor
```

Self Path Planning (P2) sampling from Peng et al., https://arxiv.org/html/2502.03540v1.

**Arguments**:

- `logits` _Tensor_ - Predicted logits for sampling.
- `xt` _Tensor_ - Input sequence to be updated.
- `t` _Tensor_ - Time tensor (e.g., time steps or temporal info).
- `curr_step` _int_ - Current iteration in the planning process.
- `num_steps` _int_ - Total number of planning steps.
- `logit_temperature` _float_ - Temperature for logits (default: 1.0).
- `randomness` _float_ - Level of randomness introduced (default: 1.0).
- `confidence_temperature` _float_ - Temperature for confidence scoring (default: 1.0).
- `score_type` _Literal["confidence", "random"]_ - Sampling score type (default: "confidence").
- `fix_mask` _Optional[Tensor]_ - Initial mask that is True at positions that are not mask tokens (default: None).

**Returns**:

- `Tensor` - Updated input sequence xt after iterative unmasking.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMtopk_lowest_masking"></a>

#### topk\_lowest\_masking

```python
def topk_lowest_masking(scores: Tensor, cutoff_len: Tensor)
```

Generates a mask for the lowest-scoring elements up to a specified cutoff length.

**Arguments**:

- `scores` _Tensor_ - Input scores tensor with shape (..., num_elements).
- `cutoff_len` _Tensor_ - Number of lowest-scoring elements to mask (per batch element).

**Returns**:

- `Tensor` - Boolean mask tensor with the same shape as `scores`, where `True` indicates that the corresponding element is among the `cutoff_len` lowest scores.

**Example**:

    >>> scores = torch.tensor([[0.9, 0.8, 0.1, 0.05], [0.7, 0.4, 0.3, 0.2]])
    >>> cutoff_len = 2
    >>> mask = topk_lowest_masking(scores, cutoff_len)
    >>> print(mask)
    tensor([[False, False, True, True],
            [False, False, True, True]])

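The masking rule can be sketched without tensors: for each row, sort the positions by score and flag the first `cutoff_len` of them. This plain-Python version takes one integer cutoff per row and breaks ties by position. The library's tensor implementation may handle ties and the `cutoff_len` shape differently. The stand-alone function reuses the method's name only for readability.

```python
def topk_lowest_masking(scores, cutoff_len):
    """Flag the cutoff_len[b] lowest scores in each row b (plain-Python sketch)."""
    mask = []
    for row, k in zip(scores, cutoff_len):
        # Positions ordered from lowest to highest score (stable sort keeps ties in order).
        order = sorted(range(len(row)), key=lambda i: row[i])
        lowest = set(order[:k])
        mask.append([i in lowest for i in range(len(row))])
    return mask


print(topk_lowest_masking([[0.9, 0.8, 0.1, 0.05], [0.7, 0.4, 0.3, 0.2]], [2, 2]))
# [[False, False, True, True], [False, False, True, True]]
```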
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstochastic_sample_from_categorical"></a>

#### stochastic\_sample\_from\_categorical

```python
def stochastic_sample_from_categorical(logits: Tensor,
                                       temperature: float = 1.0,
                                       noise_scale: float = 1.0)
```

Stochastically samples from a categorical distribution defined by input logits, with optional temperature and noise scaling for diverse sampling.

**Arguments**:

- `logits` _Tensor_ - Input logits tensor with shape (..., num_categories).
- `temperature` _float, optional_ - Softmax temperature. Higher values produce more uniform samples. Defaults to 1.0.
- `noise_scale` _float, optional_ - Scale for Gumbel noise. Higher values produce more diverse samples. Defaults to 1.0.

**Returns**:

- `tuple` - A `(tokens, scores)` pair, where:
  - **tokens** (LongTensor): Sampled category indices, with shape `logits.shape[:-1]`.
  - **scores** (Tensor): Log-softmax scores of the sampled tokens, with shape `logits.shape[:-1]`.

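This sampling step is an instance of the Gumbel-max trick: add scaled Gumbel noise to temperature-scaled logits, then take the argmax. The single-row sketch below is an illustration, not the library's implementation. It assumes `temperature > 0` and that the returned score is the log-softmax of the perturbed logits at the chosen index. Whether moco scores the perturbed or the original logits is an assumption here, and `sample_from_categorical` is a hypothetical name.

```python
import math
import random


def sample_from_categorical(logits, temperature=1.0, noise_scale=1.0, rng=None):
    """Gumbel-max sampling for one row of logits (sketch; requires temperature > 0)."""
    rng = rng or random.Random()
    perturbed = []
    for x in logits:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))  # standard Gumbel(0, 1) sample
        perturbed.append(x / temperature + noise_scale * gumbel)
    # Numerically stable log-softmax over the perturbed logits.
    m = max(perturbed)
    log_z = m + math.log(sum(math.exp(p - m) for p in perturbed))
    token = max(range(len(perturbed)), key=lambda i: perturbed[i])
    return token, perturbed[token] - log_z


# With noise_scale=0 the sample is simply the argmax of the logits.
token, score = sample_from_categorical([1.0, 3.0, 2.0], noise_scale=0.0)
print(token)  # 1
```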
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matching"></a>

# bionemo.moco.interpolants.continuous\_time.discrete.discrete\_flow\_matching

<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcher"></a>

## DiscreteFlowMatcher Objects

```python
class DiscreteFlowMatcher(Interpolant)
```

A Discrete Flow Model (DFM) interpolant.

<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcher__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: DiscretePriorDistribution,
             device: str = "cpu",
             eps: Float = 1e-5,
             rng_generator: Optional[torch.Generator] = None)
```

Initialize the DFM interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The time distribution for the diffusion process.
- `prior_distribution` _DiscretePriorDistribution_ - The prior distribution for the discrete masked tokens.
- `device` _str, optional_ - The device to use for computations. Defaults to "cpu".
- `eps` _Float_ - Small value to prevent division by zero. Defaults to 1e-5.
- `rng_generator` _Optional[torch.Generator]_ - An optional `torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherinterpolate"></a>

#### interpolate

```python
def interpolate(data: Tensor, t: Tensor, noise: Tensor)
```

Get x(t) at the given time t from noise and data.

**Arguments**:

- `data` _Tensor_ - Target discrete ids.
- `t` _Tensor_ - Time.
- `noise` _Tensor_ - Noise ids.

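Discrete flow matching commonly uses a linear probability path in which each token of x(t) equals the data token with probability t and the noise token otherwise. The sketch below assumes that path and a `t` of shape `(batch,)`. The name `dfm_interpolate` and the `generator` argument are hypothetical and not the library method.

```python
import torch


def dfm_interpolate(data, t, noise, generator=None):
    """Hypothetical sketch: keep each data token with probability t, else use the noise token."""
    # Reshape t from (batch,) so it broadcasts over the remaining dims of data.
    t = t.view(-1, *([1] * (data.dim() - 1)))
    keep_data = torch.rand(data.shape, generator=generator, device=data.device) < t
    return torch.where(keep_data, data, noise)
```

At `t = 0` the result is pure noise and at `t = 1` it is the data, matching the endpoints of the path.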
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherloss"></a>

#### loss

```python
def loss(logits: Tensor,
         target: Tensor,
         time: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         use_weight: Bool = False)
```

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the batch x node target.
If you use a masked prior, pass in the correct mask so that the loss is computed only on masked states,
i.e. `mask = data_mask * is_masked_state`, where `is_masked_state` is computed with `self.prior_dist.is_masked(xt)`.

If `use_weight` is True, the loss is weighted by 1/(1-t), as defined in Equation 24 in Appendix C of https://arxiv.org/pdf/2402.04997.

**Arguments**:

- `logits` _Tensor_ - The predicted output from the model, with shape batch x node x class.
- `target` _Tensor_ - The target output for the model prediction, with shape batch x node.
- `time` _Optional[Tensor], optional_ - The time at which the loss is calculated, used for the 1/(1-t) weight when `use_weight` is True. Defaults to None.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `use_weight` _bool, optional_ - Whether to use the DFM time weight for the loss. Defaults to False.

**Returns**:

- `Tensor` - The calculated loss batch tensor.

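The masked, optionally time-weighted cross-entropy can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation. The name `dfm_loss` is hypothetical. Three further choices are assumptions: the per-sample reduction (a masked mean over nodes), the placement of `eps` in the 1/(1-t) weight, and a `time` of shape `(batch,)`.

```python
import torch
import torch.nn.functional as F


def dfm_loss(logits, target, time=None, mask=None, use_weight=False, eps=1e-5):
    """Hypothetical sketch of a per-sample DFM cross-entropy loss.

    logits: (batch, node, class); target: (batch, node); time: (batch,).
    """
    # cross_entropy expects the class dimension second: (batch, class, node).
    ce = F.cross_entropy(logits.transpose(1, 2), target, reduction="none")  # (batch, node)
    if use_weight:
        # Time weighting 1 / (1 - t); eps avoids division by zero as t -> 1.
        ce = ce / (1.0 - time.view(-1, 1) + eps)
    if mask is not None:
        # Average only over nodes selected by the mask; empty rows give 0.
        ce = ce * mask
        return ce.sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return ce.mean(dim=-1)
```

With uniform (all-zero) logits over C classes, every node contributes log(C), which makes the weighting easy to check by hand.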
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherstep"></a>

#### step

```python
def step(logits: Tensor,
         t: Tensor,
         xt: Tensor,
         dt: Tensor | float,
         temperature: Float = 1.0,
         stochasticity: Float = 1.0) -> Tensor
```

Perform a single Euler update step of the DFM sampler.

**Arguments**:

- `logits` _Tensor_ - The input logits.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current state.
- `dt` _Tensor | float_ - The time step increment.
- `temperature` _Float, optional_ - The temperature for the softmax calculation. Defaults to 1.0.
- `stochasticity` _Float, optional_ - The stochasticity value for the step calculation. Defaults to 1.0.

**Returns**:

- `Tensor` - The updated state.

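For a masked prior, an Euler step with stochasticity η can be sketched using the jump rates from Campbell et al. (https://arxiv.org/pdf/2402.04997). Over [t, t + dt], a masked token is unmasked with probability `dt * (1 + η t) / (1 - t)`, and an unmasked token is re-masked with probability `dt * η`. Here η corresponds to `stochasticity`.

This is a hypothetical illustration. Several pieces are assumptions: the function name, the explicit `mask_index` argument, `dt` being a float, and drawing a candidate clean token at every position. The library's `step` may differ in these details and also supports non-masked priors.

```python
import torch


def masked_dfm_euler_step(logits, t, xt, dt, mask_index, temperature=1.0,
                          stochasticity=1.0, generator=None):
    """Hypothetical Euler step for a masked-prior DFM (not the library implementation)."""
    t = t.view(-1, 1)  # (batch, 1) so it broadcasts over nodes
    logits = logits.clone()
    logits[..., mask_index] = float("-inf")  # the clean-data prediction is never the mask token
    probs = torch.softmax(logits / temperature, dim=-1)
    # Draw a candidate clean token x1 for every position.
    x1 = torch.multinomial(probs.flatten(0, -2), 1, generator=generator).view(xt.shape)
    is_masked = xt == mask_index
    # Jump probabilities over [t, t + dt]: unmask at rate (1 + eta * t) / (1 - t), re-mask at rate eta.
    p_unmask = (dt * (1 + stochasticity * t) / (1 - t)).clamp(max=1.0)
    p_remask = dt * stochasticity
    # One uniform draw per position suffices because masked and unmasked sets are disjoint.
    u = torch.rand(xt.shape, generator=generator, device=xt.device)
    out = torch.where(is_masked & (u < p_unmask), x1, xt)
    remask = ~is_masked & (u < p_remask)
    return torch.where(remask, torch.full_like(xt, mask_index), out)
```

With `stochasticity=0` the step only ever unmasks tokens. Larger values add re-masking, which lets the sampler revise earlier choices.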
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherstep_purity"></a>

#### step\_purity

```python
def step_purity(logits: Tensor,
                t: Tensor,
                xt: Tensor,
                dt: Tensor | float,
                temperature: Float = 1.0,
                stochasticity: Float = 1.0) -> Tensor
```

Perform a single step of purity sampling.

Reference: https://github.com/jasonkyuyim/multiflow/blob/6278899970523bad29953047e7a42b32a41dc813/multiflow/data/interpolant.py#L346

TODO: check whether -1e9 and 1e-9 are small enough, or whether using `torch.inf` would be better.

Here is a high-level overview of what the function does:

1. Preprocessing:
   - Checks if `dt` is a float and converts it to a tensor if necessary.
   - Pads `t` and `dt` to match the shape of `xt`.
   - Checks that `mask_index` is valid (i.e., within the range of possible discrete values).
2. Masking:
   - Sets the logits corresponding to `mask_index` to a low value (-1e9) to effectively mask out those values.
   - Computes the softmax probabilities of the logits.
   - Sets the probability of `mask_index` to a small value (1e-9) to avoid numerical issues.
3. Purity sampling:
   - Computes the maximum log probabilities of the softmax distribution.
   - Computes the indices of the top-`number_to_unmask` samples with the highest log probabilities.
   - Uses these indices to sample new values from the original distribution.
4. Unmasking and updating:
   - Creates a mask to select the top-`number_to_unmask` samples.
   - Uses this mask to update the current state `xt` with the new samples.
5. Re-masking:
   - Generates a new mask to randomly re-mask some of the updated samples.
   - Applies this mask to the updated state `xt`.

**Arguments**:

- `logits` _Tensor_ - The input logits.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current state.
- `dt` _Tensor | float_ - The time step increment.
- `temperature` _Float, optional_ - The temperature for the softmax calculation. Defaults to 1.0.
- `stochasticity` _Float, optional_ - The stochasticity value for the step calculation. Defaults to 1.0.

**Returns**:

- `Tensor` - The updated state.

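The stages of purity sampling can be sketched compactly as follows. This is a hypothetical illustration modeled on the listed steps and the linked MultiFlow code, not the library's implementation. The following are assumptions:

- the function name and the explicit `mask_index` argument;
- `dt` being a float;
- drawing `number_to_unmask` as a Binomial sample over the masked positions with rate `dt * (1 + η t) / (1 - t)`.

```python
import torch


def purity_step(logits, t, xt, dt, mask_index, temperature=1.0,
                stochasticity=1.0, generator=None):
    """Hypothetical purity-sampling step for a masked prior (illustration only)."""
    t = t.view(-1, 1)  # (batch, 1)
    logits = logits.clone()
    logits[..., mask_index] = -1e9  # masking: never propose the mask token
    probs = torch.softmax(logits / temperature, dim=-1)
    is_masked = xt == mask_index
    # Confidence = max log-probability; already unmasked positions rank last.
    confidence = probs.max(dim=-1).values.log().masked_fill(~is_masked, float("-inf"))
    # Binomial(num_masked, p_unmask) count of positions to unmask in this step.
    p_unmask = (dt * (1 + stochasticity * t) / (1 - t)).clamp(max=1.0)
    u = torch.rand(xt.shape, generator=generator, device=xt.device)
    number_to_unmask = (is_masked & (u < p_unmask)).sum(dim=-1, keepdim=True)
    # Rank of each position by confidence (0 = most confident).
    rank = confidence.argsort(dim=-1, descending=True).argsort(dim=-1)
    unmask = is_masked & (rank < number_to_unmask)
    # Sample new tokens and write them only into the selected positions.
    x1 = torch.multinomial(probs.flatten(0, -2), 1, generator=generator).view(xt.shape)
    xt = torch.where(unmask, x1, xt)
    # Re-mask each position with probability dt * stochasticity.
    remask = torch.rand(xt.shape, generator=generator, device=xt.device) < dt * stochasticity
    return torch.where(remask, torch.full_like(xt, mask_index), xt)
```

Unlike the plain Euler step, the positions that get unmasked are always the most confident masked ones. Only their count is random.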
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherstep_argmax"></a>

#### step\_argmax

```python
def step_argmax(model_out: Tensor)
```

Returns the index of the maximum value in the last dimension of the model output.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.

<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherstep_simple_sample"></a>

#### step\_simple\_sample

```python
def step_simple_sample(model_out: Tensor,
                       temperature: float = 1.0,
                       num_samples: int = 1)
```

Samples from the model output logits. This leads to more diversity than `step_argmax`.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `temperature` _float, optional_ - The temperature for the softmax calculation. Defaults to 1.0.
- `num_samples` _int, optional_ - Number of samples to return. Defaults to 1.

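The difference between argmax decoding and temperature sampling can be sketched as follows. Both functions are hypothetical stand-ins, not the library methods. In particular, the output shape of `simple_sample` (a trailing `num_samples` dimension) is this sketch's own choice.

```python
import torch


def argmax_decode(model_out: torch.Tensor) -> torch.Tensor:
    # Deterministic: index of the largest logit in the last dimension.
    return model_out.argmax(dim=-1)


def simple_sample(model_out: torch.Tensor, temperature: float = 1.0,
                  num_samples: int = 1, generator=None) -> torch.Tensor:
    """Hypothetical sketch: draw num_samples categories per position from softmax(logits / T)."""
    probs = torch.softmax(model_out / temperature, dim=-1)
    flat = probs.reshape(-1, probs.shape[-1])
    samples = torch.multinomial(flat, num_samples, replacement=True, generator=generator)
    # Output shape: (*model_out.shape[:-1], num_samples)
    return samples.view(*model_out.shape[:-1], num_samples)
```

Raising `temperature` flattens `probs` and increases diversity. As the temperature approaches 0, sampling approaches `argmax_decode`.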
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_sampler"></a>

# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.ot\_sampler

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSampler"></a>

## OTSampler Objects

```python
class OTSampler()
```

Sampler for the exact mini-batch optimal transport plan.

OTSampler samples coordinates according to an OT plan (with respect to the squared Euclidean cost),
with different implementations of the plan calculation. The code is adapted from https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSampler__init__"></a>

#### \_\_init\_\_

```python
def __init__(method: str = "exact",
             device: Union[str, torch.device] = "cpu",
             num_threads: int = 1) -> None
```

Initialize the OTSampler class.

**Arguments**:

- `method` _str_ - The optimal transport solver to use. Currently only exact OT solvers (pot.emd) are supported.
- `device` _Union[str, torch.device], optional_ - The device on which to run the sampler, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `num_threads` _Union[int, str], optional_ - Number of threads to use for the OT solver. If "max", uses the maximum number of threads. Defaults to 1.

**Raises**:

- `ValueError` - If the OT solver is not recognized.
- `NotImplementedError` - If the OT solver is not implemented.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSamplerto_device"></a>

#### to\_device

```python
def to_device(device: str)
```

Moves all internal tensors to the specified device and updates the `self.device` attribute.

**Arguments**:

- `device` _str_ - The device to move the tensors to (e.g. "cpu", "cuda:0").

**Notes**:

This method transfers the internal state of the OTSampler to a different device.
It updates the `self.device` attribute to reflect the new device and moves all internal tensors to the specified device.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSamplersample_map"></a>

#### sample\_map

```python
def sample_map(pi: Tensor,
               batch_size: int,
               replace: Bool = False) -> Tuple[Tensor, Tensor]
```

Draw source and target samples from pi, $(x,z) \sim \pi$.

**Arguments**:

- `pi` _Tensor_ - Shape (bs, bs), the OT matrix between noise and data in the minibatch.
- `batch_size` _int_ - The batch size of the minibatch.
- `replace` _bool_ - Whether to sample with replacement from the OT plan. Defaults to False.

**Returns**:

- `Tuple` - A tuple of 2 tensors: the indices of the noise and data samples drawn from pi.

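Drawing index pairs from a plan amounts to treating the normalized plan as a joint distribution over (noise index, data index). You sample flat indices from it and split each one back into a row and a column. The sketch below is a hypothetical PyTorch version of that flatten-and-divide idea. The name and the `generator` argument are assumptions, and it is not the library's code.

```python
import torch


def sample_map(pi: torch.Tensor, batch_size: int, replace: bool = False, generator=None):
    """Hypothetical sketch: draw (noise_idx, data_idx) pairs with probability proportional to pi."""
    n_cols = pi.shape[1]
    p = pi.flatten() / pi.sum()  # treat the plan as a joint distribution over index pairs
    flat_idx = torch.multinomial(p, batch_size, replacement=replace, generator=generator)
    # Recover row (noise) and column (data) indices from the flat index.
    rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
    cols = flat_idx % n_cols
    return rows, cols
```

For an exact plan on equal-size batches (a scaled permutation matrix), sampling without replacement returns every matched pair exactly once.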
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSamplerget_ot_matrix"></a>

#### get\_ot\_matrix

```python
def get_ot_matrix(x0: Tensor,
                  x1: Tensor,
                  mask: Optional[Tensor] = None) -> Tensor
```

Compute the OT matrix between a source and a target minibatch.

**Arguments**:

- `x0` _Tensor_ - Shape (bs, *dim), noise from the source minibatch.
- `x1` _Tensor_ - Shape (bs, *dim), data from the target minibatch.
- `mask` _Optional[Tensor], optional_ - Mask of shape (batchsize, nodes) to apply; if not provided, no mask is applied. Defaults to None.

**Returns**:

- `p` _Tensor_ - Shape (bs, bs), the OT matrix between noise and data in the minibatch.

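With uniform weights and equal batch sizes, an exact OT plan can always be chosen as a permutation matrix scaled by 1/n (a consequence of the Birkhoff-von Neumann theorem). Computing the plan therefore reduces to an assignment problem on the squared Euclidean cost matrix.

The sketch below makes that structure visible with a brute-force search over permutations instead of POT's `emd` solver. It is a hypothetical illustration, practical only for tiny batches, and it ignores the `mask` argument.

```python
import itertools

import torch


def exact_ot_matrix(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
    """Hypothetical brute-force exact OT plan for tiny, equal-size minibatches."""
    n = x0.shape[0]
    # Squared Euclidean cost between flattened samples: (n, n)
    cost = torch.cdist(x0.reshape(n, -1), x1.reshape(n, -1)) ** 2
    rows = torch.arange(n)
    # Pick the permutation (noise i -> data perm[i]) with the lowest total cost.
    best = min(itertools.permutations(range(n)),
               key=lambda perm: cost[rows, torch.tensor(perm)].sum().item())
    pi = torch.zeros(n, n)
    pi[rows, torch.tensor(best)] = 1.0 / n  # uniform marginals: each row/column sums to 1/n
    return pi
```

In practice a dedicated solver such as POT's `emd` is used. The brute force only shows what kind of matrix comes back.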
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSamplerapply_augmentation"></a>

#### apply\_augmentation

```python
def apply_augmentation(
    x0: Tensor,
    x1: Tensor,
    mask: Optional[Tensor] = None,
    replace: Bool = False,
    sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0"
) -> Tuple[Tensor, Tensor, Optional[Tensor]]
```

Sample indices for noise and data in the minibatch according to the OT plan.

Compute the OT plan $\pi$ (with respect to the squared Euclidean cost) between a source and a target
minibatch, and draw source and target samples from pi, $(x,z) \sim \pi$.

**Arguments**:

- `x0` _Tensor_ - Shape (bs, *dim), noise from the source minibatch.
- `x1` _Tensor_ - Shape (bs, *dim), data from the target minibatch.
- `mask` _Optional[Tensor], optional_ - Mask of shape (batchsize, nodes) to apply; if not provided, no mask is applied. Defaults to None.
- `replace` _bool_ - Whether to sample with replacement from the OT plan. Defaults to False.
- `sort` _Optional[Literal["noise", "x0", "data", "x1"]], optional_ - Which of x0 (`"noise"` or `"x0"`) or x1 (`"data"` or `"x1"`) to sort by. Defaults to `"x0"`.

**Returns**:

- `Tuple` - A tuple of 2 tensors, or 3 tensors if a mask is used: the noise samples (plus the mask) and the data samples following the OT plan pi.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_sampler"></a>

# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.equivariant\_ot\_sampler

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSampler"></a>

## EquivariantOTSampler Objects

```python
class EquivariantOTSampler()
```

Sampler for the mini-batch optimal transport plan, with the cost calculated after Kabsch alignment.

EquivariantOTSampler samples coordinates according to an OT plan
(with respect to the squared Euclidean cost after Kabsch alignment), with different implementations of the plan calculation.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSampler__init__"></a>

#### \_\_init\_\_

```python
def __init__(method: str = "exact",
             device: Union[str, torch.device] = "cpu",
             num_threads: int = 1) -> None
```

Initialize the EquivariantOTSampler class.

**Arguments**:

- `method` _str_ - The optimal transport solver to use. Currently only exact OT solvers (pot.emd) are supported.
- `device` _Union[str, torch.device], optional_ - The device on which to run the sampler, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `num_threads` _Union[int, str], optional_ - Number of threads to use for the OT solver. If "max", uses the maximum number of threads. Defaults to 1.

**Raises**:

- `ValueError` - If the OT solver is not recognized.
- `NotImplementedError` - If the OT solver is not implemented.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplerto_device"></a>

#### to\_device

```python
def to_device(device: str)
```

Moves all internal tensors to the specified device and updates the `self.device` attribute.

**Arguments**:

- `device` _str_ - The device to move the tensors to (e.g. "cpu", "cuda:0").

**Notes**:

This method transfers the internal state of the EquivariantOTSampler to a different device.
It updates the `self.device` attribute to reflect the new device and moves all internal tensors to the specified device.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplersample_map"></a>

#### sample\_map

```python
def sample_map(pi: Tensor,
               batch_size: int,
               replace: Bool = False) -> Tuple[Tensor, Tensor]
```

Draw source and target samples from pi, $(x,z) \sim \pi$.

**Arguments**:

- `pi` _Tensor_ - Shape (bs, bs), the OT matrix between noise and data in the minibatch.
- `batch_size` _int_ - The batch size of the minibatch.
- `replace` _bool_ - Whether to sample with replacement from the OT plan. Defaults to False.

**Returns**:

- `Tuple` - A tuple of 2 tensors: the indices of the noise and data samples drawn from pi.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplerkabsch_align"></a>

#### kabsch\_align

```python
def kabsch_align(target: Tensor, noise: Tensor) -> Tensor
```

Find the rotation matrix R that minimizes the RMSD between target @ R.T and noise.

**Arguments**:

- `target` _Tensor_ - Shape (N, *dim), data from the source minibatch.
- `noise` _Tensor_ - Shape (N, *dim), noise from the source minibatch.

**Returns**:

- `R` _Tensor_ - Shape (*dim, *dim), the rotation matrix.

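The Kabsch algorithm behind this method can be sketched with an SVD of the cross-covariance matrix. This is a minimal illustration, not the library code. The function name is hypothetical, the sketch assumes 2-D inputs of shape `(N, D)`, and it centers both point sets first. It also applies the standard determinant correction, so that R is a proper rotation rather than a reflection.

```python
import torch


def kabsch_rotation(target: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: rotation R minimizing RMSD between target @ R.T and noise.

    Both inputs have shape (N, D) and are centered here (an assumption of this sketch).
    """
    target = target - target.mean(dim=0, keepdim=True)
    noise = noise - noise.mean(dim=0, keepdim=True)
    # Cross-covariance between the point sets: (D, D)
    h = target.T @ noise
    u, _, vh = torch.linalg.svd(h)
    # Flip the last singular direction if needed so det(R) = +1 (a proper rotation).
    d = torch.sign(torch.det(vh.T @ u.T))
    diag = torch.ones(h.shape[0], dtype=h.dtype, device=h.device)
    diag[-1] = d
    return vh.T @ torch.diag(diag) @ u.T
```

If `noise` is an exactly rotated and translated copy of `target`, the sketch recovers that rotation up to floating-point error.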
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplerget_ot_matrix"></a>

#### get\_ot\_matrix

```python
def get_ot_matrix(x0: Tensor,
                  x1: Tensor,
                  mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]
```

Compute the OT matrix between a source and a target minibatch.

**Arguments**:

- `x0` _Tensor_ - Shape (bs, *dim), noise from the source minibatch.
- `x1` _Tensor_ - Shape (bs, *dim), data from the target minibatch.
- `mask` _Optional[Tensor], optional_ - Mask of shape (batchsize, nodes) to apply; if not provided, no mask is applied. Defaults to None.

**Returns**:

- `p` _Tensor_ - Shape (bs, bs), the OT matrix between noise and data in the minibatch.
- `Rs` _Tensor_ - Shape (bs, bs, *dim, *dim), the rotation matrices between noise and data in the minibatch.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplerapply_augmentation"></a>

#### apply\_augmentation

```python
def apply_augmentation(
    x0: Tensor,
    x1: Tensor,
    mask: Optional[Tensor] = None,
    replace: Bool = False,
    sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0"
) -> Tuple[Tensor, Tensor, Optional[Tensor]]
```

Sample indices for noise and data in the minibatch according to the OT plan.

Compute the OT plan $\pi$ (with respect to the squared Euclidean cost after Kabsch alignment) between a source and a target
minibatch, and draw source and target samples from pi, $(x,z) \sim \pi$.

**Arguments**:

- `x0` _Tensor_ - Shape (bs, *dim), noise from the source minibatch.
- `x1` _Tensor_ - Shape (bs, *dim), data from the target minibatch.
- `mask` _Optional[Tensor], optional_ - Mask of shape (batchsize, nodes) to apply; if not provided, no mask is applied. Defaults to None.
- `replace` _bool_ - Whether to sample with replacement from the OT plan. Defaults to False.
- `sort` _Optional[Literal["noise", "x0", "data", "x1"]], optional_ - Which of x0 (`"noise"` or `"x0"`) or x1 (`"data"` or `"x1"`) to sort by. Defaults to `"x0"`.

**Returns**:

- `Tuple` - A tuple of 2 tensors: the noise and data samples following the OT plan pi.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentation"></a>

# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.kabsch\_augmentation

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentation"></a>

## KabschAugmentation Objects

```python
class KabschAugmentation()
```

Point-wise Kabsch alignment.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentation__init__"></a>

#### \_\_init\_\_

```python
def __init__()
```

Initialize the KabschAugmentation instance.

**Notes**:

- This implementation assumes no required initialization arguments.
- You can add instance variables (e.g., `self.variable_name`) as needed.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentationkabsch_align"></a>

#### kabsch\_align

```python
def kabsch_align(target: Tensor, noise: Tensor)
```

Find the rotation matrix R that minimizes the RMSD between target @ R.T and noise.

**Arguments**:

- `target` _Tensor_ - Shape (N, *dim), data from the source minibatch.
- `noise` _Tensor_ - Shape (N, *dim), noise from the source minibatch.

**Returns**:

- `R` _Tensor_ - Shape (*dim, *dim), the rotation matrix.
- Aligned target (Tensor): The target tensor, rotated and shifted to reduce the RMSD with noise.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentationbatch_kabsch_align"></a>

#### batch\_kabsch\_align

```python
def batch_kabsch_align(target: Tensor, noise: Tensor)
```

Find the rotation matrix R that minimizes the RMSD between target @ R.T and noise.

**Arguments**:

- `target` _Tensor_ - Shape (B, N, *dim), data from the source minibatch.
- `noise` _Tensor_ - Shape (B, N, *dim), noise from the source minibatch.

**Returns**:

- `R` _Tensor_ - Shape (*dim, *dim), the rotation matrix.
- Aligned target (Tensor): The target tensor, rotated and shifted to reduce the RMSD with noise.

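A batched version of the Kabsch alignment can be written with batched SVD. This hypothetical sketch assumes inputs of shape `(B, N, D)`. It returns one rotation per batch element, shape `(B, D, D)`, which is this sketch's own choice. It also returns the target rotated about its centroid and shifted onto the noise centroid.

```python
import torch


def batch_kabsch_align(target: torch.Tensor, noise: torch.Tensor):
    """Hypothetical batched Kabsch sketch for inputs of shape (B, N, D).

    Returns per-sample rotations (B, D, D) and the target rotated and shifted onto noise.
    """
    t_mean = target.mean(dim=1, keepdim=True)
    n_mean = noise.mean(dim=1, keepdim=True)
    t_c, n_c = target - t_mean, noise - n_mean
    h = t_c.transpose(1, 2) @ n_c  # (B, D, D) cross-covariance
    u, _, vh = torch.linalg.svd(h)
    v = vh.transpose(1, 2)
    d = torch.sign(torch.det(v @ u.transpose(1, 2)))  # (B,) reflection correction
    diag = torch.ones(h.shape[:-1], dtype=h.dtype, device=h.device)  # (B, D)
    diag[:, -1] = d
    R = v @ torch.diag_embed(diag) @ u.transpose(1, 2)
    aligned = t_c @ R.transpose(1, 2) + n_mean  # rotate, then move onto the noise centroid
    return R, aligned
```

Swapping the two arguments aligns the noise to the data instead. That is the direction `align_noise_to_data=True` refers to in `apply_augmentation`.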
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentationapply_augmentation"></a>

#### apply\_augmentation

```python
def apply_augmentation(x0: Tensor,
                       x1: Tensor,
                       mask: Optional[Tensor] = None,
                       align_noise_to_data=True) -> Tuple[Tensor, Tensor]
```

Align each noise and data pair in the minibatch with the Kabsch algorithm.

Each noise sample is rotated and translated to minimize its squared Euclidean distance to its paired data sample (or vice versa, depending on `align_noise_to_data`). The minibatch is not permuted.

**Arguments**:

- `x0` _Tensor_ - Shape (bs, *dim), noise from the source minibatch.
- `x1` _Tensor_ - Shape (bs, *dim), data from the target minibatch.
- `mask` _Optional[Tensor], optional_ - Mask of shape (batchsize, nodes) to apply; if not provided, no mask is applied. Defaults to None.
- `align_noise_to_data` _bool, optional_ - Direction of alignment. Defaults to True, meaning the noise is aligned to reduce its error with respect to the data.

**Returns**:

- `Tuple` - A tuple of 2 tensors: the noise and data samples after alignment.

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentation"></a>

# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationaugmentation_types"></a>

# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.augmentation\_types

<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationaugmentation_typesAugmentationType"></a>

## AugmentationType Objects

```python
class AugmentationType(Enum)
```

An enumeration representing the types of optimal transport that can be used in continuous flow matching.

- **EXACT_OT**: Standard minibatch optimal transport, as defined in https://arxiv.org/pdf/2302.00482.
- **EQUIVARIANT_OT**: Adds roto-translation optimization to minibatch OT; see https://arxiv.org/pdf/2306.15030 and https://arxiv.org/pdf/2312.07168 (Sec. 4.2).
- **KABSCH**: Simple Kabsch alignment between each data and noise point, with no permutation; see https://arxiv.org/pdf/2410.22388 (Sec. 3.2).

These augmentation types control how noise and data samples are paired and aligned within a minibatch.

<a id="mocointerpolantscontinuous_timecontinuous"></a>

# bionemo.moco.interpolants.continuous\_time.continuous

<a id="mocointerpolantscontinuous_timecontinuousvdm"></a>

# bionemo.moco.interpolants.continuous\_time.continuous.vdm

<a id="mocointerpolantscontinuous_timecontinuousvdmVDM"></a>

## VDM Objects

```python
class VDM(Interpolant)
```

A Variational Diffusion Model (VDM) interpolant.

**Examples**:

```python
import torch
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
from bionemo.moco.interpolants.continuous_time.continuous.vdm import VDM
from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

vdm = VDM(
    time_distribution=UniformTimeDistribution(...),
    prior_distribution=GaussianPrior(...),
    noise_schedule=CosineSNRTransform(...),
)
model = Model(...)  # placeholder: your denoising network

# Training
for epoch in range(1000):
    data = data_loader.get(...)  # placeholder: one training batch
    time = vdm.sample_time(batch_size)
    noise = vdm.sample_prior(data.shape)
    xt = vdm.interpolate(data, time, noise)  # argument order is (data, t, noise)

    x_pred = model(xt, time)
    loss = vdm.loss(x_pred, data, time)
    loss.mean().backward()  # reduce per-sample losses to a scalar before backpropagating

# Generation
x_pred = vdm.sample_prior(data.shape)
for t in LinearInferenceSchedule(...).generate_schedule():
    time = torch.full((batch_size,), t)
    x_hat = model(x_pred, time)
    x_pred = vdm.step(x_hat, time, x_pred)
# x_pred now holds the generated samples
```

<a id="mocointerpolantscontinuous_timecontinuousvdmVDM__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             noise_schedule: ContinuousSNRTransform,
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
             device: Union[str, torch.device] = "cpu",
             rng_generator: Optional[torch.Generator] = None)
```

Initializes the VDM interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps, used to sample time points for the diffusion process.
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable, used as the starting point for the diffusion process.
- `noise_schedule` _ContinuousSNRTransform_ - The noise schedule, defining the amount of noise added at each time step.
- `prediction_type` _Union[PredictionType, str], optional_ - The type of prediction, e.g. "data", "noise", or "v_prediction" (see `process_data_prediction`). Defaults to `PredictionType.DATA`.
- `device` _Union[str, torch.device], optional_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMinterpolate"></a> |
|
|
3026 |
|
|
|
3027 |
#### interpolate |
|
|
3028 |
|
|
|
3029 |
```python |
|
|
3030 |
def interpolate(data: Tensor, t: Tensor, noise: Tensor) |
|
|
3031 |
``` |
|
|
3032 |
|
|
|
3033 |
Get x(t) with given time t from noise and data. |
|
|
3034 |
|
|
|
3035 |
**Arguments**: |
|
|
3036 |
|
|
|
3037 |
- `data` _Tensor_ - target |
|
|
3038 |
- `t` _Tensor_ - time |
|
|
3039 |
- `noise` _Tensor_ - noise from prior() |

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMforward_process"></a>

#### forward\_process

```python
def forward_process(data: Tensor, t: Tensor, noise: Optional[Tensor] = None)
```

Compute x(t) at the given time t from noise and data.

**Arguments**:

- `data` _Tensor_ - The target data.
- `t` _Tensor_ - The time.
- `noise` _Tensor, optional_ - Noise sampled from the prior (e.g. via `sample_prior`). Defaults to None.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMprocess_data_prediction"></a>

#### process\_data\_prediction

```python
def process_data_prediction(model_output: Tensor, sample, t)
```

Converts the model output to a data prediction based on the prediction type.

This conversion follows Progressive Distillation for Fast Sampling of Diffusion Models (https://arxiv.org/pdf/2202.00512).
Given the model output and the sample, the output is converted to a data prediction according to the prediction type:

- For "noise" prediction type: `pred_data = (sample - noise_scale * model_output) / data_scale`
- For "data" prediction type: `pred_data = model_output`
- For "v_prediction" prediction type: `pred_data = data_scale * sample - noise_scale * model_output`

**Arguments**:

- `model_output` _Tensor_ - The output of the model.
- `sample` _Tensor_ - The input sample.
- `t` _Tensor_ - The time step.

**Returns**:

The data prediction based on the prediction type.

**Raises**:

- `ValueError` - If the prediction type is not one of "noise", "data", or "v_prediction".
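
To make the three formulas concrete, here is a minimal sketch in plain PyTorch (not the moco API) that writes them as one standalone function. It assumes `data_scale` plays the role of alpha_t and `noise_scale` of sigma_t, with the variance-preserving relation `alpha_t**2 + sigma_t**2 = 1`, which the "v_prediction" formula relies on. The helper name `to_data_prediction` is hypothetical.

```python
import torch


def to_data_prediction(model_output, sample, data_scale, noise_scale, prediction_type="data"):
    """Hypothetical helper mirroring the three conversion formulas above."""
    if prediction_type == "noise":
        return (sample - noise_scale * model_output) / data_scale
    if prediction_type == "data":
        return model_output
    if prediction_type == "v_prediction":
        return data_scale * sample - noise_scale * model_output
    raise ValueError(f"Unknown prediction type: {prediction_type}")


# Build x_t from known data and noise with alpha**2 + sigma**2 = 1.
torch.manual_seed(0)
x0 = torch.randn(4, 3)
eps = torch.randn(4, 3)
alpha = torch.tensor(0.8)         # data_scale
sigma = torch.sqrt(1 - alpha**2)  # noise_scale
xt = alpha * x0 + sigma * eps
v = alpha * eps - sigma * x0      # v-prediction target (Salimans & Ho)

# Each parameterization recovers the same clean data from an exact model output.
for kind, output in [("noise", eps), ("data", x0), ("v_prediction", v)]:
    recovered = to_data_prediction(output, xt, alpha, sigma, kind)
    print(kind, torch.allclose(recovered, x0, atol=1e-5))
```

Feeding each parameterization its exact target shows that all three recover the same clean data. The "v_prediction" case works because `alpha * x_t - sigma * v = (alpha**2 + sigma**2) * x0`.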

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMprocess_noise_prediction"></a>

#### process\_noise\_prediction

```python
def process_noise_prediction(model_output: Tensor, sample: Tensor, t: Tensor)
```

Like `process_data_prediction`, but converts the model output to a noise prediction. Only the "noise" prediction type is supported.

**Arguments**:

- `model_output` _Tensor_ - The output of the model.
- `sample` _Tensor_ - The input sample.
- `t` _Tensor_ - The time step.

**Returns**:

The model output as the noise prediction, if the prediction type is "noise".

**Raises**:

- `ValueError` - If the prediction type is not "noise".

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMstep"></a>

#### step

```python
def step(model_out: Tensor,
         t: Tensor,
         xt: Tensor,
         dt: Tensor,
         mask: Optional[Tensor] = None,
         center: Bool = False,
         temperature: Float = 1.0)
```

Perform one integration step.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current data point.
- `dt` _Tensor_ - The time step increment.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _Bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low-temperature sampling. Defaults to 1.0.

**Notes**:

The temperature parameter controls the trade-off between diversity and sample quality.
Decreasing the temperature sharpens the sampling distribution to focus on more likely samples.
The impact of low-temperature sampling must be evaluated through ablation.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMscore"></a>

#### score

```python
def score(x_hat: Tensor, xt: Tensor, t: Tensor)
```

Converts the data prediction into an estimate of the score function.

**Arguments**:

- `x_hat` _Tensor_ - The predicted data point.
- `xt` _Tensor_ - The current data point.
- `t` _Tensor_ - The time step.

**Returns**:

The estimated score function.
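
The relation behind this conversion can be sketched in a few lines. The sketch assumes the Gaussian perturbation kernel `q(x_t | x_0) = N(alpha_t * x_0, sigma_t**2 I)` and plugs the data prediction in for x_0. The moco method may differ in details such as masking or how t is broadcast. The helper name is hypothetical, and autograd is used only to check the closed form.

```python
import torch
from torch.distributions import Normal


def score_from_data_prediction(x_hat, xt, alpha_t, sigma_t):
    """Hypothetical helper: score of N(alpha_t * x_0, sigma_t**2 I) at x_t, with x_0 = x_hat."""
    # grad_{x_t} log N(x_t; alpha_t * x_0, sigma_t**2 I) = -(x_t - alpha_t * x_0) / sigma_t**2
    return (alpha_t * x_hat - xt) / sigma_t**2


torch.manual_seed(0)
x0 = torch.randn(2, 5)
alpha_t, sigma_t = torch.tensor(0.6), torch.tensor(0.8)
xt = (alpha_t * x0 + sigma_t * torch.randn(2, 5)).requires_grad_(True)

# Compare against autograd on the Gaussian log-density.
log_q = Normal(alpha_t * x0, sigma_t).log_prob(xt).sum()
(autograd_score,) = torch.autograd.grad(log_q, xt)
analytic_score = score_from_data_prediction(x0, xt.detach(), alpha_t, sigma_t)
print(torch.allclose(autograd_score, analytic_score, atol=1e-5))
```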

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMstep_ddim"></a>

#### step\_ddim

```python
def step_ddim(model_out: Tensor,
              t: Tensor,
              xt: Tensor,
              dt: Tensor,
              mask: Optional[Tensor] = None,
              eta: Float = 0.0,
              center: Bool = False)
```

Perform one DDIM sampling step.

In terms of the DDPM equations, `alpha_bar = alpha**2` and `1 - alpha**2 = sigma**2`.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current data point.
- `dt` _Tensor_ - The time step increment.
- `mask` _Optional[Tensor], optional_ - A mask for the data point. Defaults to None.
- `eta` _Float, optional_ - The DDIM stochasticity parameter; 0.0 gives deterministic DDIM sampling. Defaults to 0.0.
- `center` _Bool, optional_ - Whether to center the data point. Defaults to False.
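
Because `alpha_bar = alpha**2` and `1 - alpha_bar = sigma**2`, the DDIM update of Song et al. can be written directly in alpha/sigma form. The sketch below does this in plain PyTorch. It assumes the model output has already been converted to a data prediction `x_hat` and that `alpha_s > alpha_t` (s is the less noisy level). It ignores `mask` and `center`. The function name is hypothetical, not the moco API.

```python
import torch


def ddim_step_sketch(x_hat, xt, alpha_t, sigma_t, alpha_s, sigma_s, eta=0.0, generator=None):
    """Hypothetical DDIM update from noise level t to the less noisy level s.

    Uses alpha_bar = alpha**2 and 1 - alpha_bar = sigma**2, so the DDIM
    formulas can be written directly in terms of alpha and sigma.
    """
    # Noise implied by the data prediction: x_t = alpha_t * x_hat + sigma_t * eps_hat.
    eps_hat = (xt - alpha_t * x_hat) / sigma_t
    # DDIM's per-step noise level sigma_t(eta), rewritten with alpha and sigma.
    sigma_eta = eta * (sigma_s / sigma_t) * torch.sqrt(1 - alpha_t**2 / alpha_s**2)
    z = torch.randn(xt.shape, generator=generator, dtype=xt.dtype, device=xt.device)
    return alpha_s * x_hat + torch.sqrt(sigma_s**2 - sigma_eta**2) * eps_hat + sigma_eta * z


torch.manual_seed(0)
x0, eps = torch.randn(3, 4), torch.randn(3, 4)
alpha_t, sigma_t = torch.tensor(0.6), torch.tensor(0.8)  # noisier level t
alpha_s, sigma_s = torch.tensor(0.8), torch.tensor(0.6)  # cleaner level s
xt = alpha_t * x0 + sigma_t * eps

# With eta = 0 and a perfect data prediction, DDIM moves x_t deterministically
# to alpha_s * x0 + sigma_s * eps, reusing the same noise direction.
xs = ddim_step_sketch(x0, xt, alpha_t, sigma_t, alpha_s, sigma_s, eta=0.0)
print(torch.allclose(xs, alpha_s * x0 + sigma_s * eps, atol=1e-5))
```

Setting `eta = 1.0` injects fresh Gaussian noise at each step, while `eta = 0.0` keeps the sampler deterministic.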

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMset_loss_weight_fn"></a>

#### set\_loss\_weight\_fn

```python
def set_loss_weight_fn(fn: Callable)
```

Sets the `loss_weight` attribute of the instance to the given function.

**Arguments**:

- `fn` - The function to set as the `loss_weight` attribute. It should take three arguments: `raw_loss`, `t`, and `weight_type`.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMloss_weight"></a>

#### loss\_weight

```python
def loss_weight(raw_loss: Tensor,
                t: Tensor,
                weight_type: str,
                dt: Float = 0.001) -> Tensor
```

Calculates the weight for the loss based on the given weight type.

This function computes the loss weight according to the specified `weight_type`.
The available weight types are:

- "ones": uniform weight of 1.0
- "data_to_noise": derived from Equation (9) of https://arxiv.org/pdf/2202.00512
- "variational_objective_discrete": based on the variational objective, see https://arxiv.org/pdf/2202.00512

**Arguments**:

- `raw_loss` _Tensor_ - The raw loss calculated from the model prediction and target.
- `t` _Tensor_ - The time step.
- `weight_type` _str_ - The type of weight to use. Can be "ones", "data_to_noise", or "variational_objective_discrete".
- `dt` _Float, optional_ - The time step increment. Defaults to 0.001.

**Returns**:

- `Tensor` - The weight for the loss.

**Raises**:

- `ValueError` - If the weight type is not recognized.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMloss"></a>

#### loss

```python
def loss(model_pred: Tensor,
         target: Tensor,
         t: Tensor,
         dt: Optional[Float] = 0.001,
         mask: Optional[Tensor] = None,
         weight_type: str = "ones")
```

Calculates the loss given the model prediction, target, and time.

**Arguments**:

- `model_pred` _Tensor_ - The predicted output from the model.
- `target` _Tensor_ - The target output for the model prediction.
- `t` _Tensor_ - The time at which the loss is calculated.
- `dt` _Optional[Float], optional_ - The time step increment. Defaults to 0.001.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `weight_type` _str, optional_ - The type of weight to use for the loss. Can be "ones", "data_to_noise", or "variational_objective_discrete" (see `loss_weight`). Defaults to "ones".

**Returns**:

- `Tensor` - The calculated loss batch tensor.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMstep_hybrid_sde"></a>

#### step\_hybrid\_sde

```python
def step_hybrid_sde(model_out: Tensor,
                    t: Tensor,
                    xt: Tensor,
                    dt: Tensor,
                    mask: Optional[Tensor] = None,
                    center: Bool = False,
                    temperature: Float = 1.0,
                    equilibrium_rate: Float = 0.0) -> Tensor
```

Perform one integration step of the hybrid Langevin-reverse-time SDE.

See Section B.3 (page 37) of https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current data point.
- `dt` _Tensor_ - The time step increment.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _Bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low-temperature sampling. Defaults to 1.0.
- `equilibrium_rate` _Float, optional_ - The rate of Langevin equilibration. Scales the amount of Langevin dynamics per unit time. Values in the range [1.0, 5.0] work best. Defaults to 0.0.

**Notes**:

For all step functions that use the SDE formulation, note that we integrate backwards in time, which leads to an apparent sign change.
A clear example can be seen on slide 29 of https://ernestryu.com/courses/FM/diffusion1.pdf.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMstep_ode"></a>

#### step\_ode

```python
def step_ode(model_out: Tensor,
             t: Tensor,
             xt: Tensor,
             dt: Tensor,
             mask: Optional[Tensor] = None,
             center: Bool = False,
             temperature: Float = 1.0) -> Tensor
```

Perform one integration step of the ODE.

See Section B (page 36) of https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current data point.
- `dt` _Tensor_ - The time step increment.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _Bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low-temperature sampling. Defaults to 1.0.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matching"></a>

# bionemo.moco.interpolants.continuous\_time.continuous.continuous\_flow\_matching

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcher"></a>

## ContinuousFlowMatcher Objects

```python
class ContinuousFlowMatcher(Interpolant)
```

A Continuous Flow Matching interpolant.

**Examples**:

```python
>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching import ContinuousFlowMatcher
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

flow_matcher = ContinuousFlowMatcher(
    time_distribution = UniformTimeDistribution(...),
    prior_distribution = GaussianPrior(...),
    prediction_type = "flow",  # the model below predicts the vector field
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = flow_matcher.sample_time(batch_size)
    noise = flow_matcher.sample_prior(data.shape)
    noise, data = flow_matcher.apply_augmentation(noise, data)  # Optional, only for OT
    xt = flow_matcher.interpolate(data, time, noise)
    flow = flow_matcher.calculate_target(data, noise)

    u_pred = model(xt, time)
    loss = flow_matcher.loss(u_pred, flow, target_type = "flow")  # per-sample losses
    loss.mean().backward()

# Generation
x_pred = flow_matcher.sample_prior(data.shape)
inference_sched = LinearInferenceSchedule(...)
for t in inference_sched.generate_schedule():
    time = inference_sched.pad_time(x_pred.shape[0], t)
    u_hat = model(x_pred, time)
    dt = ...  # step size for this iteration
    x_pred = flow_matcher.step(u_hat, x_pred, dt, time)
return x_pred
```

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcher__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
             sigma: Float = 0,
             augmentation_type: Optional[Union[AugmentationType, str]] = None,
             augmentation_num_threads: int = 1,
             data_scale: Float = 1.0,
             device: Union[str, torch.device] = "cpu",
             rng_generator: Optional[torch.Generator] = None,
             eps: Float = 1e-5)
```

Initializes the Continuous Flow Matching interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps, used to sample time points for the diffusion process.
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable, used as the starting point for the diffusion process.
- `prediction_type` _PredictionType, optional_ - The type of prediction the model makes, either "flow" or "data". Defaults to PredictionType.DATA.
- `sigma` _Float, optional_ - The standard deviation of the Gaussian noise added to the interpolated data. Defaults to 0.
- `augmentation_type` _Optional[Union[AugmentationType, str]], optional_ - The type of optimal transport augmentation to apply, if any. Defaults to None.
- `augmentation_num_threads` - Number of threads to use for the OT solver. If "max", uses the maximum number of threads. Defaults to 1.
- `data_scale` _Float, optional_ - The scale factor for the data. Defaults to 1.0.
- `device` _Union[str, torch.device], optional_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `rng_generator` _Optional[torch.Generator], optional_ - An optional `torch.Generator` for reproducible sampling. Defaults to None.
- `eps` _Float, optional_ - Small value to prevent division by zero. Defaults to 1e-5.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherapply_augmentation"></a>

#### apply\_augmentation

```python
def apply_augmentation(x0: Tensor,
                       x1: Tensor,
                       mask: Optional[Tensor] = None,
                       **kwargs) -> tuple
```

Sample and apply the optimal transport plan between batched (and optionally masked) x0 and x1.

**Arguments**:

- `x0` _Tensor_ - Noise from the source minibatch, shape `(bs, *dim)`.
- `x1` _Tensor_ - Data from the target minibatch, shape `(bs, *dim)`.
- `mask` _Optional[Tensor], optional_ - Mask to apply to the output, shape (batchsize, nodes). If not provided, no mask is applied. Defaults to None.
- `**kwargs` - Additional keyword arguments passed to `self.augmentation_sampler.apply_augmentation` or handled within this method.

**Returns**:

- `Tuple` - A tuple of 2 tensors: the noise and data samples paired according to the OT plan pi.
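
To illustrate what pairing noise and data "according to the OT plan" means, here is a brute-force exact solver for a tiny minibatch with uniform weights and squared Euclidean cost. This is a swapped-in illustration, not necessarily how moco solves the problem. It ignores `mask` and `**kwargs`, and the function name is hypothetical.

```python
import itertools

import torch


def exact_minibatch_ot(x0, x1):
    """Hypothetical brute-force OT pairing for tiny, equally sized minibatches.

    With uniform weights on both sides, an optimal transport plan can be taken
    to be a permutation, so enumerating all permutations finds an exact solution.
    Only practical for very small batches (n! candidates).
    """
    cost = torch.cdist(x0.flatten(1), x1.flatten(1)) ** 2  # squared Euclidean cost
    n = cost.shape[0]
    rows = torch.arange(n)
    best_perm = min(
        itertools.permutations(range(n)),
        key=lambda perm: cost[rows, torch.tensor(perm)].sum().item(),
    )
    # Row i of x0 is paired with row best_perm[i] of x1.
    return x0, x1[list(best_perm)]


torch.manual_seed(0)
noise = torch.randn(5, 3)
data = noise[torch.randperm(5)] + 1e-3 * torch.randn(5, 3)  # shuffled, slightly perturbed copy
noise_ot, data_ot = exact_minibatch_ot(noise, data)
print((data_ot - noise_ot).norm(dim=-1))  # small distances: the shuffle is undone
```

Real minibatch OT solvers avoid the factorial enumeration, but the returned pairing has the same meaning: noise and data rows that end up at the same index are transported onto each other.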

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherundo_scale_data"></a>

#### undo\_scale\_data

```python
def undo_scale_data(data: Tensor) -> Tensor
```

Downscale the input data by the data scale factor.

**Arguments**:

- `data` _Tensor_ - The input data to downscale.

**Returns**:

The downscaled data.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherscale_data"></a>

#### scale\_data

```python
def scale_data(data: Tensor) -> Tensor
```

Upscale the input data by the data scale factor.

**Arguments**:

- `data` _Tensor_ - The input data to upscale.

**Returns**:

The upscaled data.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherinterpolate"></a>

#### interpolate

```python
def interpolate(data: Tensor, t: Tensor, noise: Tensor) -> Tensor
```

Get x_t at the given time t from noise (x_0) and data (x_1).

Currently, linear interpolation is used, as defined in:

1. Rectified flow: https://arxiv.org/abs/2209.03003.
2. Conditional flow matching: https://arxiv.org/abs/2210.02747 (called conditional optimal transport).

**Arguments**:

- `data` _Tensor_ - Target, shape (batchsize, nodes, features).
- `t` _Tensor_ - Time, shape (batchsize).
- `noise` _Tensor_ - Noise sampled from the prior, shape (batchsize, nodes, features).

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatchercalculate_target"></a>

#### calculate\_target

```python
def calculate_target(data: Tensor,
                     noise: Tensor,
                     mask: Optional[Tensor] = None) -> Tensor
```

Get the target vector field. For the linear interpolation path this is `x1 - x0` (data minus noise), which does not depend on t.

**Arguments**:

- `data` _Tensor_ - Target, shape (batchsize, nodes, features).
- `noise` _Tensor_ - Noise sampled from the prior, shape (batchsize, nodes, features).
- `mask` _Optional[Tensor], optional_ - Mask to apply to the output, shape (batchsize, nodes). If not provided, no mask is applied. Defaults to None.

**Returns**:

- `Tensor` - The target vector field.
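
Here is a minimal sketch of the linear path and its target. It assumes `sigma = 0` and `data_scale = 1`, with noise as x_0 and data as x_1, following the conventions stated above. The helper names are hypothetical, and the moco methods may handle masking and scaling differently.

```python
import torch


def linear_interpolate(data, t, noise):
    """Hypothetical sketch of x_t = t * x_1 + (1 - t) * x_0 (sigma = 0, no data scaling)."""
    t = t.view(-1, *([1] * (data.dim() - 1)))  # broadcast t of shape (batch,) over nodes/features
    return t * data + (1 - t) * noise


def linear_target(data, noise, mask=None):
    """Hypothetical sketch of the conditional vector field along the straight path."""
    target = data - noise  # d x_t / dt = x_1 - x_0, independent of t
    if mask is not None:
        target = target * mask.unsqueeze(-1)  # mask has shape (batch, nodes)
    return target


torch.manual_seed(0)
data, noise = torch.randn(2, 5, 3), torch.randn(2, 5, 3)
t = torch.tensor([0.2, 0.7])
xt = linear_interpolate(data, t, noise)

# A finite difference in t along the path reproduces the target field.
h = 1e-3
fd = (linear_interpolate(data, t + h, noise) - xt) / h
print(torch.allclose(fd, linear_target(data, noise), atol=1e-2))
```

The endpoints behave as expected: t = 0 returns the noise and t = 1 returns the data.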

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherprocess_vector_field_prediction"></a>

#### process\_vector\_field\_prediction

```python
def process_vector_field_prediction(model_output: Tensor,
                                    xt: Optional[Tensor] = None,
                                    t: Optional[Tensor] = None,
                                    mask: Optional[Tensor] = None)
```

Process the model output based on the prediction type to compute the vector field.

**Arguments**:

- `model_output` _Tensor_ - The output of the model.
- `xt` _Tensor_ - The input sample.
- `t` _Tensor_ - The time step.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the model output. Defaults to None.

**Returns**:

The vector field prediction based on the prediction type.

**Raises**:

- `ValueError` - If the prediction type is not "flow" or "data".

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherprocess_data_prediction"></a>

#### process\_data\_prediction

```python
def process_data_prediction(model_output: Tensor,
                            xt: Optional[Tensor] = None,
                            t: Optional[Tensor] = None,
                            mask: Optional[Tensor] = None)
```

Process the model output based on the prediction type to generate clean data.

**Arguments**:

- `model_output` _Tensor_ - The output of the model.
- `xt` _Tensor_ - The input sample.
- `t` _Tensor_ - The time step.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the model output. Defaults to None.

**Returns**:

The data prediction based on the prediction type.

**Raises**:

- `ValueError` - If the prediction type is not "flow".

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherstep"></a>

#### step

```python
def step(model_out: Tensor,
         xt: Tensor,
         dt: Tensor,
         t: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         center: Bool = False)
```

Perform a single ODE integration step using the Euler method.

**Arguments**:

- `model_out` _Tensor_ - The output of the model at the current time step.
- `xt` _Tensor_ - The current intermediate state.
- `dt` _Tensor_ - The time step size.
- `t` _Tensor, optional_ - The current time. Defaults to None.
- `mask` _Optional[Tensor], optional_ - A mask to apply to the model output. Defaults to None.
- `center` _Bool, optional_ - Whether to center the output. Defaults to False.

**Returns**:

- `x_next` _Tensor_ - The updated state of the system after the single step, x_(t+dt).

**Notes**:

- If a mask is provided, it is applied element-wise to the model output before scaling.
- The `clean` method is called on the updated state before it is returned.
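
The Euler update itself is a single line. The sketch below uses a hypothetical helper in plain PyTorch and ignores `center` and the `clean` call. It shows that integrating the exact conditional vector field of the linear path, starting from noise at t = 0, reaches the data at t = 1. This works because that field is constant along the straight path.

```python
import torch


def euler_step(v, xt, dt, mask=None):
    """Hypothetical Euler update x_{t+dt} = x_t + v * dt (no centering or cleaning)."""
    if mask is not None:
        v = v * mask.unsqueeze(-1)  # zero out velocities of masked (padded) nodes
    return xt + v * dt


torch.manual_seed(0)
noise, data = torch.randn(2, 4, 3), torch.randn(2, 4, 3)
num_steps = 10
dt = 1.0 / num_steps

# Integrate from t = 0 (noise) to t = 1 (data) with an "oracle" model that
# returns the exact conditional vector field x_1 - x_0 of the linear path.
x = noise.clone()
for _ in range(num_steps):
    x = euler_step(data - noise, x, dt)
print(torch.allclose(x, data, atol=1e-4))
```

With a learned model, the velocity changes along the trajectory, so the number of steps starts to matter.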

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherstep_score_stochastic"></a>

#### step\_score\_stochastic

```python
def step_score_stochastic(model_out: Tensor,
                          xt: Tensor,
                          dt: Tensor,
                          t: Tensor,
                          mask: Optional[Tensor] = None,
                          gt_mode: str = "tan",
                          gt_p: Float = 1.0,
                          gt_clamp: Optional[Float] = None,
                          score_temperature: Float = 1.0,
                          noise_temperature: Float = 1.0,
                          t_lim_ode: Float = 0.99,
                          center: Bool = False)
```

Perform a single SDE integration step using a score-based Langevin update:

`d x_t = [v(x_t, t) + g(t) * s(x_t, t) * score_temperature] dt + sqrt(2 * g(t) * noise_temperature) dw_t`

**Arguments**:

- `model_out` _Tensor_ - The output of the model at the current time step.
- `xt` _Tensor_ - The current intermediate state.
- `dt` _Tensor_ - The time step size.
- `t` _Tensor_ - The current time.
- `mask` _Optional[Tensor], optional_ - A mask to apply to the model output. Defaults to None.
- `gt_mode` _str, optional_ - The mode for the g(t) function, "us" or "tan" (see `get_gt`). Defaults to "tan".
- `gt_p` _Float, optional_ - The parameter for the g(t) function. Defaults to 1.0.
- `gt_clamp` _Optional[Float], optional_ - Upper limit of the g(t) term. Defaults to None.
- `score_temperature` _Float, optional_ - The temperature for the score part of the step. Defaults to 1.0.
- `noise_temperature` _Float, optional_ - The temperature for the stochastic part of the step. Defaults to 1.0.
- `t_lim_ode` _Float, optional_ - The time limit for the ODE step. Defaults to 0.99.
- `center` _Bool, optional_ - Whether to center the output. Defaults to False.

**Returns**:

- `x_next` _Tensor_ - The updated state of the system after the single step, x_(t+dt).

**Notes**:

- If a mask is provided, it is applied element-wise to the model output before scaling.
- The `clean` method is called on the updated state before it is returned.
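
Read literally, the SDE above can be discretized with an Euler-Maruyama step. This sketch makes several simplifying assumptions:

- `v` and `s` are already available (for example, `s` obtained from `v` with `vf_to_score`).
- `g(t)` is passed in as a plain number.
- Masking, centering, clamping and the `t_lim_ode` switch are omitted.

The helper name is hypothetical.

```python
import torch


def score_sde_step(v, score, xt, dt, gt, score_temperature=1.0, noise_temperature=1.0, generator=None):
    """Hypothetical Euler-Maruyama step for
    d x_t = [v + g(t) * s * score_temperature] dt + sqrt(2 * g(t) * noise_temperature) dw_t.
    """
    drift = v + gt * score * score_temperature
    diffusion = (2.0 * gt * noise_temperature * dt) ** 0.5  # std of the noise increment
    noise = torch.randn(xt.shape, generator=generator, dtype=xt.dtype, device=xt.device)
    return xt + drift * dt + diffusion * noise


# With zero drift, one step from x = 0 has standard deviation sqrt(2 * g * dt).
xt = torch.zeros(100_000)
g = torch.Generator().manual_seed(0)
x_next = score_sde_step(torch.zeros_like(xt), torch.zeros_like(xt), xt, dt=0.01, gt=0.5, generator=g)
print(x_next.std())  # close to sqrt(2 * 0.5 * 0.01) = 0.1
```

Setting `noise_temperature = 0` removes the stochastic term and reduces the update to a deterministic Euler step on the modified drift.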

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherloss"></a>

#### loss

```python
def loss(model_pred: Tensor,
         target: Tensor,
         t: Optional[Tensor] = None,
         xt: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         target_type: Union[PredictionType, str] = PredictionType.DATA)
```

Calculate the loss given the model prediction, data sample, time, and mask.

- If `target_type` is FLOW: `loss = ||v_hat - (x1 - x0)||**2`.
- If `target_type` is DATA: `loss = ||x1_hat - x1||**2 * 1 / (1 - t)**2`, because the target vector field is `x1 - x0 = (x1 - xt) / (1 - t)`, where `xt = t * x1 + (1 - t) * x0`.

This function supports any combination of `prediction_type` and `target_type` in {DATA, FLOW}.

**Arguments**:

- `model_pred` _Tensor_ - The predicted output from the model.
- `target` _Tensor_ - The target output for the model prediction.
- `t` _Optional[Tensor], optional_ - The time for the model prediction. Defaults to None.
- `xt` _Optional[Tensor], optional_ - The interpolated data. Defaults to None.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `target_type` _Union[PredictionType, str], optional_ - The type of the target output. Defaults to PredictionType.DATA.

**Returns**:

- `Tensor` - The calculated loss batch tensor.
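
The identity behind the DATA/FLOW correspondence can be checked numerically. The sketch below converts a data prediction into a vector field and compares both forms of the loss. It uses plain PyTorch with a hypothetical helper name, applies no masking, and reduces with a per-sample sum rather than whatever reduction moco uses.

```python
import torch


def data_to_vector_field(x1_hat, xt, t):
    """Hypothetical conversion of a data prediction into a vector field: (x1_hat - x_t) / (1 - t)."""
    t = t.view(-1, *([1] * (xt.dim() - 1)))
    return (x1_hat - xt) / (1 - t)


torch.manual_seed(0)
x0, x1 = torch.randn(3, 4, 2), torch.randn(3, 4, 2)  # noise, data
t = torch.tensor([0.1, 0.5, 0.9])
xt = t.view(-1, 1, 1) * x1 + (1 - t.view(-1, 1, 1)) * x0
x1_hat = x1 + 0.1 * torch.randn_like(x1)  # an imperfect data prediction

# Flow-space loss ||v_hat - (x1 - x0)||^2 per sample ...
flow_loss = (data_to_vector_field(x1_hat, xt, t) - (x1 - x0)).pow(2).sum(dim=(1, 2))
# ... equals the data-space loss ||x1_hat - x1||^2 weighted by 1 / (1 - t)^2.
data_loss = (x1_hat - x1).pow(2).sum(dim=(1, 2)) / (1 - t) ** 2
print(torch.allclose(flow_loss, data_loss, rtol=1e-3))
```

The `1 / (1 - t)**2` factor grows quickly as t approaches 1. This growth is what the `eps` constructor argument is there to guard against near the data end of the path.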

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatchervf_to_score"></a>

#### vf\_to\_score

```python
def vf_to_score(x_t: Tensor, v: Tensor, t: Tensor) -> Tensor
```

From Geffner et al. Computes the score of the noisy density given the vector field learned by flow matching.

With our interpolation scheme these are related by

`v(x_t, t) = (1 / t) * (x_t + scale_ref ** 2 * (1 - t) * s(x_t, t))`,

or equivalently,

`s(x_t, t) = (t * v(x_t, t) - x_t) / (scale_ref ** 2 * (1 - t))`,

with `scale_ref = 1`.

**Arguments**:

- `x_t` - Noisy sample, shape [*, dim].
- `v` - Vector field, shape [*, dim].
- `t` - Interpolation time, shape [*] (must be < 1).

**Returns**:

Score of the intermediate density, shape [*, dim].
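
The second relation translates directly into code. Below is a hypothetical standalone version (not the moco method itself). It is checked on the simplest case in which both the score and the vector field are known in closed form: each row's data distribution is a single point x1, and x_0 is standard Gaussian (`scale_ref = 1`).

```python
import torch


def vf_to_score_sketch(x_t, v, t, scale_ref=1.0):
    """Hypothetical implementation of s = (t * v - x_t) / (scale_ref**2 * (1 - t))."""
    t = t.unsqueeze(-1)  # t has shape [*]; x_t and v have shape [*, dim]
    return (t * v - x_t) / (scale_ref**2 * (1 - t))


# Check on a single data point x1 per row: then p_t = N(t * x1, (1 - t)**2 I), whose score
# is -(x - t * x1) / (1 - t)**2, and the marginal vector field is (x1 - x) / (1 - t).
torch.manual_seed(0)
x1 = torch.randn(4, 3)
t = torch.tensor([0.1, 0.3, 0.6, 0.9])
tt = t.unsqueeze(-1)
x = tt * x1 + (1 - tt) * torch.randn(4, 3)
v = (x1 - x) / (1 - tt)
expected = -(x - tt * x1) / (1 - tt) ** 2
print(torch.allclose(vf_to_score_sketch(x, v, t), expected, rtol=1e-4, atol=1e-4))
```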

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherget_gt"></a>

#### get\_gt

```python
def get_gt(t: Tensor,
           mode: str = "tan",
           param: float = 1.0,
           clamp_val: Optional[float] = None,
           eps: float = 1e-2) -> Tensor
```

From Geffner et al. Computes g(t) for the different modes.

**Arguments**:

- `t` _Tensor_ - Times at which to evaluate g(t), covering [0, 1), shape [nsteps].
- `mode` _str_ - "us" or "tan".
- `param` _float_ - Parameter of the transformation.
- `clamp_val` _Optional[float]_ - Value at which to clamp g(t); no clamping if None.
- `eps` _float_ - Small numerical constant; leave it at its default.

<a id="mocointerpolantscontinuous_time"></a>

# bionemo.moco.interpolants.continuous\_time

<a id="mocointerpolants"></a>

# bionemo.moco.interpolants
|
|
3774 |
|
|
|
3775 |
<a id="mocointerpolantsbatch_augmentation"></a>

# bionemo.moco.interpolants.batch\_augmentation

<a id="mocointerpolantsbatch_augmentationBatchDataAugmentation"></a>

## BatchDataAugmentation Objects

```python
class BatchDataAugmentation()
```

Facilitates the creation of batch augmentation objects based on specified optimal transport types.

**Arguments**:

- `device` _str_ - The device to use for computations (e.g., 'cpu', 'cuda').
- `num_threads` _int_ - The number of threads to utilize.

<a id="mocointerpolantsbatch_augmentationBatchDataAugmentation__init__"></a>

#### \_\_init\_\_

```python
def __init__(device, num_threads)
```

Initializes a BatchDataAugmentation instance.

**Arguments**:

- `device` _str_ - Device for computation.
- `num_threads` _int_ - Number of threads to use.

<a id="mocointerpolantsbatch_augmentationBatchDataAugmentationcreate"></a>

#### create

```python
def create(method_type: AugmentationType)
```

Creates a batch augmentation object of the specified type.

**Arguments**:

- `method_type` _AugmentationType_ - The type of optimal transport method.


**Returns**:

  The augmentation object if the type is supported, otherwise **None**.

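`create` behaves like a factory: it maps an enum member to an augmentation object and returns None for unsupported types. The sketch below illustrates that dispatch pattern in plain Python. `ExampleAugmentationType`, its members, the placeholder classes and `ExampleBatchAugmentationFactory` are all hypothetical stand-ins, not the actual `AugmentationType` members or augmentation classes in bionemo-moco.

```python
from enum import Enum
from typing import Optional


class ExampleAugmentationType(Enum):
    # Hypothetical members, for illustration only.
    EXACT = "exact"
    EQUIVARIANT = "equivariant"
    UNSUPPORTED = "unsupported"


class ExactPlaceholder:
    """Hypothetical augmentation object that just records its configuration."""

    def __init__(self, device: str, num_threads: int):
        self.device = device
        self.num_threads = num_threads


class EquivariantPlaceholder(ExactPlaceholder):
    """Second hypothetical augmentation object."""


class ExampleBatchAugmentationFactory:
    """Hypothetical factory mirroring the documented create() contract."""

    _registry = {
        ExampleAugmentationType.EXACT: ExactPlaceholder,
        ExampleAugmentationType.EQUIVARIANT: EquivariantPlaceholder,
    }

    def __init__(self, device: str, num_threads: int):
        self.device = device
        self.num_threads = num_threads

    def create(self, method_type: ExampleAugmentationType) -> Optional[ExactPlaceholder]:
        cls = self._registry.get(method_type)
        # Unsupported types return None instead of raising.
        return None if cls is None else cls(self.device, self.num_threads)
```

Returning None rather than raising means callers should check the result before using it.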
<a id="mocointerpolantsdiscrete_timediscreted3pm"></a>

# bionemo.moco.interpolants.discrete\_time.discrete.d3pm

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PM"></a>

## D3PM Objects

```python
class D3PM(Interpolant)
```

A Discrete Denoising Diffusion Probabilistic Model (D3PM) interpolant.

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PM__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: DiscretePriorDistribution,
             noise_schedule: DiscreteNoiseSchedule,
             device: str = "cpu",
             last_time_idx: int = 0,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes the D3PM interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps, used to sample time points for the diffusion process.
- `prior_distribution` _DiscretePriorDistribution_ - The prior distribution of the variable, used as the starting point for the diffusion process.
- `noise_schedule` _DiscreteNoiseSchedule_ - The noise schedule, defining the amount of noise added at each time step.
- `device` _str, optional_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `last_time_idx` _int, optional_ - The last time index to consider in the interpolation process. Defaults to 0.
- `rng_generator` - An optional `torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PMinterpolate"></a>

#### interpolate

```python
def interpolate(data: Tensor, t: Tensor)
```

Interpolates using the discrete interpolation method.

This method implements Equation 2 from the D3PM paper (https://arxiv.org/pdf/2107.03006), which
samples the discrete state `xt` at time `t` from the input data via
`q(xt | x0) = Cat(xt; p = x0 * Qt_bar)`, where `x0` is the one-hot encoded data and `Qt_bar`
is the cumulative transition matrix up to time `t`.

**Arguments**:

- `data` _Tensor_ - The input data to be interpolated.
- `t` _Tensor_ - The time step at which to interpolate.


**Returns**:

- `Tensor` - The interpolated discrete state `xt` at time `t`.

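Equation 2 can be made concrete with a small plain-Python sketch. It assumes a uniform transition matrix `Q_t = (1 - beta_t) * I + (beta_t / K) * ones` (keep the token or resample it uniformly) and a plain list of `betas`; the library instead derives its matrices from the configured noise schedule and prior. Multiplying a one-hot `x0` by `Qt_bar` simply selects row `x0` of `Qt_bar`.

```python
import random
from typing import List, Sequence

Matrix = List[List[float]]


def uniform_transition(beta: float, num_classes: int) -> Matrix:
    """Q_t = (1 - beta) * I + (beta / K) * ones; beta = 0 gives the identity."""
    off = beta / num_classes
    return [[(1.0 - beta) * (i == j) + off for j in range(num_classes)]
            for i in range(num_classes)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def cumulative_transition(betas: Sequence[float], t: int, num_classes: int) -> Matrix:
    """Qt_bar = Q_1 Q_2 ... Q_t (time steps counted from 1)."""
    q_bar = uniform_transition(0.0, num_classes)  # identity
    for beta in betas[:t]:
        q_bar = matmul(q_bar, uniform_transition(beta, num_classes))
    return q_bar


def interpolate(x0: Sequence[int], t: int, betas: Sequence[float], num_classes: int,
                rng: random.Random) -> List[int]:
    """Sample xt ~ Cat(p = onehot(x0) @ Qt_bar) independently for each position."""
    q_bar = cumulative_transition(betas, t, num_classes)
    # onehot(x0) @ Qt_bar is row x0 of Qt_bar.
    return [rng.choices(range(num_classes), weights=q_bar[token])[0] for token in x0]
```

Every `Q_t` is row-stochastic, so `Qt_bar` is too, and each row is a valid categorical distribution; with `beta = 1` a single step already reaches the uniform distribution.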
<a id="mocointerpolantsdiscrete_timediscreted3pmD3PMforward_process"></a>

#### forward\_process

```python
def forward_process(data: Tensor, t: Tensor) -> Tensor
```

Apply the forward process to the data at time t.

**Arguments**:

- `data` _Tensor_ - target discrete ids
- `t` _Tensor_ - time


**Returns**:

- `Tensor` - x(t) after applying the forward process

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PMstep"></a>

#### step

```python
def step(model_out: Tensor,
         t: Tensor,
         xt: Tensor,
         mask: Optional[Tensor] = None,
         temperature: Float = 1.0,
         model_out_is_logits: bool = True)
```

Perform a single step in the discrete interpolant method, transitioning from the current discrete state `xt` at time `t` to the next state.

This step involves:

1. Computing the predicted q-posterior logits using the model output `model_out` and the current state `xt` at time `t`.
2. Sampling the next state from the predicted q-posterior distribution using the Gumbel-Softmax trick.

**Arguments**:

- `model_out` _Tensor_ - The output of the model at the current time step, which is used to compute the predicted q-posterior logits.
- `t` _Tensor_ - The current time step, which is used to index into the transition matrices and compute the predicted q-posterior logits.
- `xt` _Tensor_ - The current discrete state at time `t`, which is used to compute the predicted q-posterior logits and sample the next state.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the next state, which can be used to mask out certain tokens or regions. Defaults to None.
- `temperature` _Float, optional_ - The temperature to use for the Gumbel-Softmax trick, which controls the randomness of the sampling process. Defaults to 1.0.
- `model_out_is_logits` _bool, optional_ - A flag indicating whether the model output is already in logits form. If True, the output is assumed to be logits; otherwise, it is converted to logits. Defaults to True.


**Returns**:

- `Tensor` - The next discrete state at time `t-1`.

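Step 2 can be sketched in plain Python with the Gumbel-max form of the trick, which draws an exact sample from `softmax(logits / temperature)`. The sketch assumes temperature divides the posterior logits before Gumbel noise is added; the library may apply it differently. `gumbel_max_sample` and `sample_next_state` are hypothetical names.

```python
import math
import random
from typing import List, Sequence


def gumbel_max_sample(logits: Sequence[float], temperature: float, rng: random.Random) -> int:
    """Draw one index from softmax(logits / temperature) via the Gumbel-max trick."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    best_idx, best_val = 0, -math.inf
    for idx, logit in enumerate(logits):
        u = rng.random()
        while u == 0.0:  # avoid log(0); random() never returns 1.0
            u = rng.random()
        gumbel = -math.log(-math.log(u))  # standard Gumbel noise
        val = logit / temperature + gumbel
        if val > best_val:
            best_idx, best_val = idx, val
    return best_idx


def sample_next_state(posterior_logits: Sequence[Sequence[float]], temperature: float,
                      rng: random.Random) -> List[int]:
    """Sample one token per position from per-position posterior logits."""
    return [gumbel_max_sample(row, temperature, rng) for row in posterior_logits]
```

Lower temperatures sharpen the distribution toward the argmax, and higher temperatures flatten it toward uniform.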
<a id="mocointerpolantsdiscrete_timediscreted3pmD3PMloss"></a>

#### loss

```python
def loss(logits: Tensor,
         target: Tensor,
         xt: Tensor,
         time: Tensor,
         mask: Optional[Tensor] = None,
         vb_scale: Float = 0.0)
```

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the target batch x node. If a mask is provided, the loss is
calculated only for the non-masked elements. Additionally, if vb_scale is greater than 0, the variational lower bound loss is
calculated and added to the total loss.

**Arguments**:

- `logits` _Tensor_ - The predicted output from the model, with shape batch x node x class.
- `target` _Tensor_ - The target output for the model prediction, with shape batch x node.
- `xt` _Tensor_ - The current data point.
- `time` _Tensor_ - The time at which the loss is calculated.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `vb_scale` _Float, optional_ - The scale factor for the variational lower bound loss. Defaults to 0.0.


**Returns**:

- `Tensor` - The calculated loss tensor. If `vb_scale` is greater than 0, the scaled variational lower bound loss is
  added to the cross-entropy loss.

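The cross-entropy part of this loss can be sketched on nested lists shaped batch x node x class. The sketch assumes the mask marks valid positions with 1 and averages over all valid positions. The library's reduction may differ (for example, per batch element), and the variational lower bound term is omitted. `masked_cross_entropy` is a hypothetical name.

```python
import math
from typing import List, Optional, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one class dimension."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_cross_entropy(logits: Sequence[Sequence[Sequence[float]]],
                         target: Sequence[Sequence[int]],
                         mask: Optional[Sequence[Sequence[int]]] = None) -> float:
    """Mean cross-entropy over (batch, node) positions where mask == 1."""
    total, count = 0.0, 0
    for b, (node_logits, node_targets) in enumerate(zip(logits, target)):
        for n, (class_logits, y) in enumerate(zip(node_logits, node_targets)):
            if mask is not None and not mask[b][n]:
                continue  # skip masked-out positions
            total -= log_softmax(class_logits)[y]
            count += 1
    return total / max(count, 1)
```

Equal logits over two classes give a loss of ln 2, which is a convenient reference value.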
<a id="mocointerpolantsdiscrete_timediscrete"></a>

# bionemo.moco.interpolants.discrete\_time.discrete

<a id="mocointerpolantsdiscrete_timecontinuousddpm"></a>

# bionemo.moco.interpolants.discrete\_time.continuous.ddpm

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPM"></a>

## DDPM Objects

```python
class DDPM(Interpolant)
```

A Denoising Diffusion Probabilistic Model (DDPM) interpolant.

**Examples**:

```python
import torch
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule

ddpm = DDPM(
    time_distribution=UniformTimeDistribution(..., discrete_time=True),
    prior_distribution=GaussianPrior(...),
    noise_schedule=DiscreteCosineNoiseSchedule(...),
)
model = Model(...)
optimizer = torch.optim.Adam(model.parameters())

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = ddpm.sample_time(batch_size)
    noise = ddpm.sample_prior(data.shape)
    xt = ddpm.interpolate(data, time, noise)  # argument order is (data, t, noise)

    x_pred = model(xt, time)
    loss = ddpm.loss(x_pred, data, time).mean()  # loss() returns a per-sample batch tensor
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Generation
x_pred = ddpm.sample_prior(data.shape)
for t in DiscreteLinearInferenceSchedule(...).generate_schedule():
    time = torch.full((batch_size,), t)
    x_hat = model(x_pred, time)
    x_pred = ddpm.step(x_hat, time, x_pred)
# x_pred now holds the generated samples
```

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPM__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             noise_schedule: DiscreteNoiseSchedule,
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
             device: Union[str, torch.device] = "cpu",
             last_time_idx: int = 0,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes the DDPM interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps, used to sample time points for the diffusion process.
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable, used as the starting point for the diffusion process.
- `noise_schedule` _DiscreteNoiseSchedule_ - The noise schedule, defining the amount of noise added at each time step.
- `prediction_type` _Union[PredictionType, str]_ - The quantity the model predicts: "data", "noise", or "v_prediction" (see `PredictionType`). Defaults to `PredictionType.DATA`.
- `device` _Union[str, torch.device]_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `last_time_idx` _int, optional_ - The last time index for discrete time. Set to 0 if the discrete time steps run T-1, ..., 0, or to 1 if they run T, ..., 1. Defaults to 0.
- `rng_generator` - An optional `torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMforward_data_schedule"></a>

#### forward\_data\_schedule

```python
@property
def forward_data_schedule() -> torch.Tensor
```

Returns the forward data schedule.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMforward_noise_schedule"></a>

#### forward\_noise\_schedule

```python
@property
def forward_noise_schedule() -> torch.Tensor
```

Returns the forward noise schedule.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMreverse_data_schedule"></a>

#### reverse\_data\_schedule

```python
@property
def reverse_data_schedule() -> torch.Tensor
```

Returns the reverse data schedule.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMreverse_noise_schedule"></a>

#### reverse\_noise\_schedule

```python
@property
def reverse_noise_schedule() -> torch.Tensor
```

Returns the reverse noise schedule.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMlog_var"></a>

#### log\_var

```python
@property
def log_var() -> torch.Tensor
```

Returns the log variance.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMalpha_bar"></a>

#### alpha\_bar

```python
@property
def alpha_bar() -> torch.Tensor
```

Returns the alpha bar values.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMalpha_bar_prev"></a>

#### alpha\_bar\_prev

```python
@property
def alpha_bar_prev() -> torch.Tensor
```

Returns the previous alpha bar values.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMinterpolate"></a>

#### interpolate

```python
def interpolate(data: Tensor, t: Tensor, noise: Tensor)
```

Computes x(t) at the given time t from the data and noise.

**Arguments**:

- `data` _Tensor_ - Target data.
- `t` _Tensor_ - Time.
- `noise` _Tensor_ - Noise sampled from the prior distribution.

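Under the standard DDPM parameterization (Ho et al., 2020), x(t) is a variance-preserving mix of data and noise: `x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`, where `alpha_bar_t` is the cumulative product of `1 - beta_s`. The sketch below assumes that parameterization and a plain list of `betas`. The library reads its coefficients from the configured noise schedule instead. `ddpm_interpolate` is a hypothetical name.

```python
import math
from typing import List, Sequence


def alpha_bars(betas: Sequence[float]) -> List[float]:
    """alpha_bar_t = prod_{s <= t} (1 - beta_s), for t = 0..T-1."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def ddpm_interpolate(x0: Sequence[float], t: int, noise: Sequence[float],
                     betas: Sequence[float]) -> List[float]:
    """x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise (elementwise)."""
    a_bar = alpha_bars(betas)[t]
    data_scale, noise_scale = math.sqrt(a_bar), math.sqrt(1.0 - a_bar)
    return [data_scale * x + noise_scale * e for x, e in zip(x0, noise)]
```

Because `data_scale ** 2 + noise_scale ** 2 = 1`, unit-variance data and noise give a unit-variance x(t) at every step.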
<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMforward_process"></a>

#### forward\_process

```python
def forward_process(data: Tensor, t: Tensor, noise: Optional[Tensor] = None)
```

Computes x(t) at the given time t from the data and noise.

**Arguments**:

- `data` _Tensor_ - Target data.
- `t` _Tensor_ - Time.
- `noise` _Tensor, optional_ - Noise sampled from the prior distribution. Defaults to None.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMprocess_data_prediction"></a>

#### process\_data\_prediction

```python
def process_data_prediction(model_output: Tensor, sample: Tensor, t: Tensor)
```

Converts the model output to a data prediction based on the prediction type.

This conversion follows Progressive Distillation for Fast Sampling of Diffusion Models (https://arxiv.org/pdf/2202.00512).
Given the model output and the sample, the output is converted to a data prediction according to the prediction type.
The conversion formulas are as follows:
- For "noise" prediction type: `pred_data = (sample - noise_scale * model_output) / data_scale`
- For "data" prediction type: `pred_data = model_output`
- For "v_prediction" prediction type: `pred_data = data_scale * sample - noise_scale * model_output`

**Arguments**:

- `model_output` _Tensor_ - The output of the model.
- `sample` _Tensor_ - The input sample.
- `t` _Tensor_ - The time step.


**Returns**:

  The data prediction based on the prediction type.


**Raises**:

- `ValueError` - If the prediction type is not one of "noise", "data", or "v_prediction".

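The three conversion formulas above can be checked on plain lists. The sketch assumes `data_scale = sqrt(alpha_bar_t)` and `noise_scale = sqrt(1 - alpha_bar_t)`, so `data_scale ** 2 + noise_scale ** 2 = 1`, which the "v_prediction" formula relies on. `to_data_prediction` is a hypothetical stand-in, not the library method.

```python
from typing import List, Sequence


def to_data_prediction(model_output: Sequence[float], sample: Sequence[float],
                       data_scale: float, noise_scale: float,
                       prediction_type: str) -> List[float]:
    """Convert a model output to a data (x0) prediction using the formulas above."""
    if prediction_type == "noise":
        return [(x - noise_scale * m) / data_scale for x, m in zip(sample, model_output)]
    if prediction_type == "data":
        return list(model_output)
    if prediction_type == "v_prediction":
        return [data_scale * x - noise_scale * m for x, m in zip(sample, model_output)]
    raise ValueError(f"Unsupported prediction type: {prediction_type!r}")
```

If the model predicts exactly the true noise, data or velocity for a given sample, all three branches return the same clean data.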
<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMprocess_noise_prediction"></a>

#### process\_noise\_prediction

```python
def process_noise_prediction(model_output, sample, t)
```

Like `process_data_prediction`, but converts the model output to a noise prediction.

**Arguments**:

- `model_output` - The output of the model.
- `sample` - The input sample.
- `t` - The time step.


**Returns**:

  The model output as noise if the prediction type is "noise".


**Raises**:

- `ValueError` - If the prediction type is not "noise".

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMcalculate_velocity"></a>

#### calculate\_velocity

```python
def calculate_velocity(data: Tensor, t: Tensor, noise: Tensor) -> Tensor
```

Calculate the velocity term given the data, time step, and noise.

**Arguments**:

- `data` _Tensor_ - The input data.
- `t` _Tensor_ - The current time step.
- `noise` _Tensor_ - The noise term.


**Returns**:

- `Tensor` - The calculated velocity term.

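In Progressive Distillation (https://arxiv.org/pdf/2202.00512), the velocity is `v = alpha_t * noise - sigma_t * data`, where `x_t = alpha_t * data + sigma_t * noise` and `alpha_t ** 2 + sigma_t ** 2 = 1`. The scalar sketch below assumes `calculate_velocity` uses that definition, with `alpha_t = sqrt(alpha_bar_t)` and `sigma_t = sqrt(1 - alpha_bar_t)`. It also shows that both data and noise can be recovered from `(x_t, v)`. Both function names are hypothetical.

```python
import math
from typing import Tuple


def velocity(data: float, noise: float, alpha_bar: float) -> float:
    """v = alpha * noise - sigma * data (scalar version)."""
    alpha, sigma = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return alpha * noise - sigma * data


def recover_from_velocity(x_t: float, v: float, alpha_bar: float) -> Tuple[float, float]:
    """Invert (x_t, v) to (data, noise) using alpha**2 + sigma**2 = 1."""
    alpha, sigma = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return alpha * x_t - sigma * v, sigma * x_t + alpha * v
```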
<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMstep"></a>

#### step

```python
@torch.no_grad()
def step(model_out: Tensor,
         t: Tensor,
         xt: Tensor,
         mask: Optional[Tensor] = None,
         center: Bool = False,
         temperature: Float = 1.0)
```

Performs one integration step.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current data point.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low temperature sampling. Defaults to 1.0.


**Notes**:

  The temperature parameter controls the level of randomness in the sampling process. A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result in less random and more deterministic samples. This can be useful for tasks that require more control over the generation process.

  For discrete time, we sample from [T-1, ..., 1, 0] for T steps, so t = 0 is sampled (hence the mask).
  For continuous time, we start from [1, 1 - dt, ..., dt] for T steps, where s = t - 1 when t = 0, i.e. dt is then 0.

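One ancestral step under the standard DDPM posterior `q(x_{t-1} | x_t, x0)` of Ho et al. (2020), given a data prediction `x0_hat`, can be sketched as follows. How temperature enters is an assumption here (it scales the posterior standard deviation), and the library may apply it differently. The explicit `is_last_step` flag is a hypothetical stand-in for the mask described above. `ddpm_posterior_step` is a hypothetical name.

```python
import math
import random
from typing import List, Sequence


def ddpm_posterior_step(x0_hat: Sequence[float], x_t: Sequence[float],
                        beta_t: float, alpha_bar_t: float, alpha_bar_prev: float,
                        temperature: float, rng: random.Random,
                        is_last_step: bool = False) -> List[float]:
    """Sample x_{t-1} ~ N(mean, var) from the DDPM posterior given a data prediction."""
    alpha_t = 1.0 - beta_t
    # Posterior mean coefficients for x0 and x_t.
    coef_x0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
    coef_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    # Posterior variance beta_tilde_t.
    var = beta_t * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    # No noise is added on the final step (t = 0).
    std = 0.0 if is_last_step else temperature * math.sqrt(var)
    return [coef_x0 * a + coef_xt * b + std * rng.gauss(0.0, 1.0)
            for a, b in zip(x0_hat, x_t)]
```

When `alpha_bar_prev = 1` (the step that lands on clean data), the posterior collapses to `x0_hat` with zero variance.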
<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMstep_noise"></a>

#### step\_noise

```python
def step_noise(model_out: Tensor,
               t: Tensor,
               xt: Tensor,
               mask: Optional[Tensor] = None,
               center: Bool = False,
               temperature: Float = 1.0)
```

Performs one integration step.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current data point.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low temperature sampling. Defaults to 1.0.


**Notes**:

  The temperature parameter controls the level of randomness in the sampling process.
  A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2)
  result in less random and more deterministic samples. This can be useful for tasks
  that require more control over the generation process.

  For discrete time, we sample from [T-1, ..., 1, 0] for T steps, so t = 0 is sampled (hence the mask).
  For continuous time, we start from [1, 1 - dt, ..., dt] for T steps, where s = t - 1 when t = 0, i.e. dt is then 0.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMscore"></a>

#### score

```python
def score(x_hat: Tensor, xt: Tensor, t: Tensor)
```

Converts the data prediction to the estimated score function.

**Arguments**:

- `x_hat` _Tensor_ - The predicted data point.
- `xt` _Tensor_ - The current data point.
- `t` _Tensor_ - The time step.


**Returns**:

  The estimated score function.

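For the variance-preserving DDPM marginal `x_t ~ N(sqrt(alpha_bar) * x0, (1 - alpha_bar) * I)`, the score is `(sqrt(alpha_bar) * x0 - x_t) / (1 - alpha_bar)`. The sketch below assumes that parameterization and substitutes the predicted `x_hat` for `x0`. `ddpm_score` is a hypothetical name.

```python
import math
from typing import List, Sequence


def ddpm_score(x_hat: Sequence[float], x_t: Sequence[float], alpha_bar: float) -> List[float]:
    """Score of N(sqrt(alpha_bar) * x_hat, (1 - alpha_bar) I) evaluated at x_t."""
    return [(math.sqrt(alpha_bar) * xh - xt) / (1.0 - alpha_bar)
            for xh, xt in zip(x_hat, x_t)]
```

With a perfect data prediction, the score reduces to `-noise / sqrt(1 - alpha_bar)`, which links score and noise prediction.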
<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMstep_ddim"></a>

#### step\_ddim

```python
def step_ddim(model_out: Tensor,
              t: Tensor,
              xt: Tensor,
              mask: Optional[Tensor] = None,
              eta: Float = 0.0,
              center: Bool = False)
```

Performs one step of DDIM sampling.

**Arguments**:

- `model_out` _Tensor_ - output of the model
- `t` _Tensor_ - current time step
- `xt` _Tensor_ - current data point
- `mask` _Optional[Tensor], optional_ - mask for the data point. Defaults to None.
- `eta` _Float, optional_ - DDIM sampling parameter. Defaults to 0.0.
- `center` _Bool, optional_ - whether to center the data point. Defaults to False.

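The DDIM update of Song et al. (Denoising Diffusion Implicit Models) can be sketched from a data prediction `x0_hat` as follows. The implied noise is recovered from `x_t`, and `eta` interpolates between deterministic sampling (`eta = 0`) and the DDPM posterior variance (`eta = 1`). This illustrates the published update under the same DDPM parameterization, not the library's exact code. `ddim_step` is a hypothetical stand-alone name.

```python
import math
import random
from typing import List, Sequence


def ddim_step(x0_hat: Sequence[float], x_t: Sequence[float], alpha_bar_t: float,
              alpha_bar_prev: float, eta: float, rng: random.Random) -> List[float]:
    """One DDIM update from x_t to x_prev given a data prediction x0_hat."""
    # Implied noise prediction from the data prediction.
    eps_hat = [(xt - math.sqrt(alpha_bar_t) * x0) / math.sqrt(1.0 - alpha_bar_t)
               for x0, xt in zip(x0_hat, x_t)]
    # eta = 0 gives deterministic DDIM; eta = 1 matches the DDPM posterior variance.
    sigma = eta * math.sqrt((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)) * math.sqrt(
        1.0 - alpha_bar_t / alpha_bar_prev)
    dir_coef = math.sqrt(max(1.0 - alpha_bar_prev - sigma ** 2, 0.0))
    return [math.sqrt(alpha_bar_prev) * x0 + dir_coef * e + sigma * rng.gauss(0.0, 1.0)
            for x0, e in zip(x0_hat, eps_hat)]
```

With `eta = 0` and an exact data prediction, the step moves `x_t` onto the marginal at the previous time using the same noise, with no fresh randomness.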
<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMset_loss_weight_fn"></a>

#### set\_loss\_weight\_fn

```python
def set_loss_weight_fn(fn)
```

Sets the loss_weight attribute of the instance to the given function.

**Arguments**:

- `fn` - The function to set as the loss_weight attribute. This function should take three arguments: raw_loss, t, and weight_type.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMloss_weight"></a>

#### loss\_weight

```python
def loss_weight(raw_loss: Tensor, t: Optional[Tensor],
                weight_type: str) -> Tensor
```

Calculates the weight for the loss based on the given weight type.

The data_to_noise loss weights are derived in Equation (9) of https://arxiv.org/pdf/2202.00512.

**Arguments**:

- `raw_loss` _Tensor_ - The raw loss calculated from the model prediction and target.
- `t` _Optional[Tensor]_ - The time step.
- `weight_type` _str_ - The type of weight to use. Can be "ones", "data_to_noise", or "noise_to_data".


**Returns**:

- `Tensor` - The weight for the loss.


**Raises**:

- `ValueError` - If the weight type is not recognized.

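Equation (9) of that paper rewrites the noise-prediction MSE as the data MSE weighted by the signal-to-noise ratio `alpha_bar / (1 - alpha_bar)`. The scalar sketch below assumes that "data_to_noise" returns this SNR and "noise_to_data" its inverse. The exact expressions used by the library are not shown on this page, so treat this as an illustration of the identity. `example_loss_weight` is a hypothetical name.

```python
def snr(alpha_bar: float) -> float:
    """Signal-to-noise ratio of the DDPM marginal at a given alpha_bar."""
    return alpha_bar / (1.0 - alpha_bar)


def example_loss_weight(alpha_bar: float, weight_type: str) -> float:
    """Hypothetical scalar stand-in for loss_weight; the weight formulas are assumptions."""
    if weight_type == "ones":
        return 1.0
    if weight_type == "data_to_noise":
        return snr(alpha_bar)
    if weight_type == "noise_to_data":
        return 1.0 / snr(alpha_bar)
    raise ValueError(f"Unrecognized weight type: {weight_type!r}")
```

The SNR-weighted data error equals the error of the noise prediction implied by the same data prediction.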
<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMloss"></a>

#### loss

```python
def loss(model_pred: Tensor,
         target: Tensor,
         t: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         weight_type: Literal["ones", "data_to_noise",
                              "noise_to_data"] = "ones")
```

Calculate the loss given the model prediction, data sample, and time.

The default weight_type, "ones", leaves the loss unchanged (it multiplies by ones).
"data_to_noise" rescales the data MSE loss into a loss that is theoretically equivalent
to noise prediction. "noise_to_data" is provided for the reverse direction, for completeness.

**Arguments**:

- `model_pred` _Tensor_ - The predicted output from the model.
- `target` _Tensor_ - The target output for the model prediction.
- `t` _Optional[Tensor]_ - The time at which the loss is calculated.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `weight_type` _Literal["ones", "data_to_noise", "noise_to_data"]_ - The type of weight to use for the loss. Defaults to "ones".


**Returns**:

- `Tensor` - The calculated loss batch tensor.

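The "loss batch tensor" can be illustrated with a plain-Python sketch of a masked, per-sample MSE. It assumes the mask marks valid positions with 1 and that each sample's loss is the mean over its valid positions. The library's exact reduction may differ, and the per-sample values would then be multiplied by the weight from `loss_weight`. `masked_mse_per_sample` is a hypothetical name.

```python
from typing import List, Optional, Sequence


def masked_mse_per_sample(pred: Sequence[Sequence[float]], target: Sequence[Sequence[float]],
                          mask: Optional[Sequence[Sequence[int]]] = None) -> List[float]:
    """Per-sample mean squared error over positions where mask == 1."""
    losses = []
    for b, (p_row, t_row) in enumerate(zip(pred, target)):
        m_row = mask[b] if mask is not None else [1] * len(p_row)
        sq = [(p - t) ** 2 for p, t, m in zip(p_row, t_row, m_row) if m]
        # A fully masked sample contributes zero loss.
        losses.append(sum(sq) / len(sq) if sq else 0.0)
    return losses
```

Keeping one value per sample, rather than reducing to a scalar, leaves room for per-sample time weighting before the final mean.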
<a id="mocointerpolantsdiscrete_timecontinuous"></a>

# bionemo.moco.interpolants.discrete\_time.continuous

<a id="mocointerpolantsdiscrete_time"></a>

# bionemo.moco.interpolants.discrete\_time

<a id="mocointerpolantsdiscrete_timeutils"></a>

# bionemo.moco.interpolants.discrete\_time.utils

<a id="mocointerpolantsdiscrete_timeutilssafe_index"></a>

#### safe\_index

```python
def safe_index(tensor: Tensor, index: Tensor, device: Optional[torch.device])
```

Safely indexes a tensor using a given index and returns the result on a specified device.

Note: a device mismatch could instead be handled by forcing the transfer with `tensor[index.to(tensor.device)].to(device)`, but that incurs a costly device migration.

**Arguments**:

- `tensor` _Tensor_ - The tensor to be indexed.
- `index` _Tensor_ - The index to use for indexing the tensor.
- `device` _Optional[torch.device]_ - The device on which the result should be returned.


**Returns**:

- `Tensor` - The indexed tensor on the specified device.


**Raises**:

- `ValueError` - If `tensor` and `index` are not on the same device.

<a id="mocointerpolantsbase_interpolant"></a>

# bionemo.moco.interpolants.base\_interpolant

<a id="mocointerpolantsbase_interpolantstring_to_enum"></a>

#### string\_to\_enum

```python
def string_to_enum(value: Union[str, AnyEnum],
                   enum_type: Type[AnyEnum]) -> AnyEnum
```

Converts a string to an enum value of the specified type. If the input is already an enum instance, it is returned as-is.

**Arguments**:

- `value` _Union[str, AnyEnum]_ - The string to convert or an existing enum instance.
- `enum_type` _Type[AnyEnum]_ - The enum type to convert to.


**Returns**:

- `AnyEnum` - The corresponding enum value.


**Raises**:

- `ValueError` - If the string does not correspond to any enum member.

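A stand-alone sketch of the documented contract follows. It assumes strings are matched case-insensitively against member values. `ExamplePredictionType` is a stand-in enum whose values mirror the prediction-type strings used elsewhere on this page, not necessarily the library's `PredictionType`.

```python
from enum import Enum
from typing import Type, TypeVar, Union

AnyEnum = TypeVar("AnyEnum", bound=Enum)


def string_to_enum(value: Union[str, AnyEnum], enum_type: Type[AnyEnum]) -> AnyEnum:
    """Return enum members unchanged; look strings up by (lower-cased) member value."""
    if isinstance(value, enum_type):
        return value
    try:
        return enum_type(value.lower())
    except ValueError:
        valid = [member.value for member in enum_type]
        raise ValueError(f"Invalid value {value!r}; expected one of {valid}") from None


class ExamplePredictionType(Enum):
    # Stand-in enum for illustration.
    DATA = "data"
    NOISE = "noise"
    VELOCITY = "v_prediction"
```

Accepting either a string or a member lets constructors such as `DDPM(prediction_type=...)` take user-friendly strings while working internally with enum members.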
<a id="mocointerpolantsbase_interpolantpad_like"></a>

#### pad\_like

```python
def pad_like(source: Tensor, target: Tensor) -> Tensor
```

Pads the dimensions of the source tensor to match the dimensions of the target tensor.

**Arguments**:

- `source` _Tensor_ - The tensor to be padded.
- `target` _Tensor_ - The tensor that the source tensor should match in dimensions.


**Returns**:

- `Tensor` - The padded source tensor.


**Raises**:

- `ValueError` - If the source tensor has more dimensions than the target tensor.


**Example**:

    >>> source = torch.tensor([1, 2, 3])  # shape: (3,)
    >>> target = torch.tensor([[1, 2], [4, 5], [7, 8]])  # shape: (3, 2)
    >>> padded_source = pad_like(source, target)  # shape: (3, 1)

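Judging from the (3,) to (3, 1) example above, the padding appends trailing singleton dimensions so that `source` broadcasts against `target`. The shape arithmetic can be sketched without PyTorch. Applying the result with `source.reshape(new_shape)` is an assumption about how the library realizes it. `pad_like_shape` is a hypothetical helper.

```python
from typing import Tuple


def pad_like_shape(source_shape: Tuple[int, ...], target_ndim: int) -> Tuple[int, ...]:
    """Append singleton dims to source_shape until it has target_ndim dimensions."""
    if len(source_shape) > target_ndim:
        raise ValueError("source has more dimensions than target")
    return tuple(source_shape) + (1,) * (target_ndim - len(source_shape))
```

A per-sample time vector of shape (batch,) padded this way can multiply a (batch, nodes, dim) tensor directly.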
<a id="mocointerpolantsbase_interpolantPredictionType"></a> |
|
|
4543 |
|
|
|
4544 |
## PredictionType Objects |
|
|
4545 |
|
|
|
4546 |
```python |
|
|
4547 |
class PredictionType(Enum) |
|
|
4548 |
``` |
|
|
4549 |
|
|
|
4550 |
An enumeration representing the type of prediction a Denoising Diffusion Probabilistic Model (DDPM) can be used for. |
|
|
4551 |
|
|
|
4552 |
DDPMs are versatile models that can be utilized for various prediction tasks, including: |
|
|
4553 |
|
|
|
4554 |
- **Data**: Predicting the original data distribution from a noisy input. |
|
|
4555 |
- **Noise**: Predicting the noise that was added to the original data to obtain the input. |
|
|
4556 |
- **Velocity**: Predicting the velocity or rate of change of the data, particularly useful for modeling temporal dynamics. |
|
|
4557 |
|
|
|
4558 |
These prediction types can be used to train neural networks for specific tasks, such as denoising, image synthesis, or time-series forecasting. |
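
Because the three targets are tied together by the forward process, one can be converted into another. The scalar sketch below shows the NOISE-to-DATA conversion. It assumes the standard DDPM forward process `x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps`, with `x_0` denoting data. The enum only mirrors the three documented members (its string values are assumptions), and `data_from_noise_prediction` is a hypothetical helper.

```python
import math
from enum import Enum


class PredictionType(Enum):
    """Hypothetical mirror of the documented members; the string values are assumptions."""
    DATA = "data"
    NOISE = "noise"
    VELOCITY = "velocity"


def data_from_noise_prediction(x_t: float, eps_hat: float, alpha_bar: float) -> float:
    """Invert x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps for x_0, given a noise estimate."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps_hat) / math.sqrt(alpha_bar)


# Noise a known data value, then recover it from a perfect NOISE prediction.
x0, eps, alpha_bar = 2.0, -0.5, 0.64
x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps  # 0.8 * 2.0 + 0.6 * (-0.5) = 1.3
print(PredictionType("noise"), round(data_from_noise_prediction(x_t, eps, alpha_bar), 6))  # PredictionType.NOISE 2.0
```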
|
|
<a id="mocointerpolantsbase_interpolantInterpolant"></a>

## Interpolant Objects

```python
class Interpolant(ABC)
```

An abstract base class for interpolants.

It defines the common structure and interface that concrete interpolants build on: sampling from the prior and time distributions, interpolating between noise and data, and taking integration steps.
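
To make the interface concrete, here is a scalar toy version of an interpolant. Everything in it is an assumption for illustration: the class names, the signatures of `interpolate` and `step`, the standard Gaussian prior, uniform times, the convention that `x0` is a prior sample and `x1` a data sample, and the use of `random.Random` in place of a `torch.Generator`.

```python
import random
from abc import ABC, abstractmethod
from typing import Optional


class ToyInterpolant(ABC):
    """Hypothetical scalar stand-in for the documented Interpolant base class."""

    def __init__(self, seed: Optional[int] = None):
        # Seeded RNG plays the role of `rng_generator` for reproducible sampling.
        self.rng = random.Random(seed)

    def sample_prior(self) -> float:
        return self.rng.gauss(0.0, 1.0)  # standard Gaussian prior (assumption)

    def sample_time(self) -> float:
        return self.rng.random()  # uniform time in [0, 1) (assumption)

    @abstractmethod
    def interpolate(self, x0: float, x1: float, t: float) -> float:
        """Return x(t) between prior sample x0 and data sample x1."""

    @abstractmethod
    def step(self, x_t: float, velocity: float, dt: float) -> float:
        """Advance x_t by one integration step of size dt."""


class LinearInterpolant(ToyInterpolant):
    def interpolate(self, x0: float, x1: float, t: float) -> float:
        return (1.0 - t) * x0 + t * x1  # straight line from x0 (t=0) to x1 (t=1)

    def step(self, x_t: float, velocity: float, dt: float) -> float:
        return x_t + dt * velocity  # one explicit Euler step


interp = LinearInterpolant(seed=0)
x0, x1 = interp.sample_prior(), 3.0
t = interp.sample_time()
x_t = interp.interpolate(x0, x1, t)
# The linear path has constant velocity x1 - x0, so one Euler step of size 1 - t reaches x1.
print(abs(interp.step(x_t, x1 - x0, 1.0 - t) - x1) < 1e-9)  # True
```

Keeping `interpolate` and `step` abstract forces each subclass to define its own path and integrator, while sampling helpers can be shared in the base class.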
|
|
<a id="mocointerpolantsbase_interpolantInterpolant__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             device: Union[str, torch.device] = "cpu",
             rng_generator: Optional[torch.Generator] = None)
```

Initializes the Interpolant class.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps.
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable.
- `device` _Union[str, torch.device], optional_ - The device on which to operate. Defaults to "cpu".
- `rng_generator` _Optional[torch.Generator], optional_ - A `torch.Generator` for reproducible sampling. Defaults to None.
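
The point of `rng_generator` is reproducibility: two `torch.Generator` objects seeded identically produce identical draws, independent of the global RNG state. The sketch below shows this with standard PyTorch calls. How the interpolant threads its generator through its sampling methods is not shown, and the helper `draw` is hypothetical.

```python
import torch


def draw(seed: int) -> torch.Tensor:
    """Hypothetical helper: draw noise and times from a dedicated, seeded generator."""
    g = torch.Generator(device="cpu")
    g.manual_seed(seed)
    noise = torch.randn(2, 3, generator=g)  # e.g. prior samples
    times = torch.rand(2, generator=g)      # e.g. time samples
    return torch.cat([noise.flatten(), times])


print(torch.equal(draw(42), draw(42)))  # True: same seed, same draws
```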
|
|
<a id="mocointerpolantsbase_interpolantInterpolantinterpolate"></a>

#### interpolate

```python
@abstractmethod
def interpolate(*args, **kwargs) -> Tensor
```

Computes x(t) from noise and data, i.e. interpolates between x0 and x1 at the given time t.
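
As one concrete example (not necessarily what every subclass implements), the linear path used in flow matching, with $x_0$ a prior sample and $x_1$ a data sample, and a matching Euler update driven by a predicted velocity $\hat{v}$, are:

$$
x_t = (1 - t)\,x_0 + t\,x_1, \qquad \frac{\mathrm{d}x_t}{\mathrm{d}t} = x_1 - x_0, \qquad t \in [0, 1]
$$

$$
x_{t+\Delta t} \approx x_t + \Delta t\,\hat{v}(x_t, t)
$$

DDPM-style interpolants use a different, schedule-dependent weighting of noise and data.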
|
|
<a id="mocointerpolantsbase_interpolantInterpolantstep"></a>

#### step

```python
@abstractmethod
def step(*args, **kwargs) -> Tensor
```

Performs a single integration step.
|
|
<a id="mocointerpolantsbase_interpolantInterpolantgeneral_step"></a>

#### general\_step

```python
def general_step(method_name: str, kwargs: dict)
```

Calls a step method of the class by its name, passing the provided keyword arguments.

**Arguments**:

- `method_name` _str_ - The name of the step method to call.
- `kwargs` _dict_ - Keyword arguments to pass to the step method.

**Returns**:

The result of the step method call.

**Raises**:

- `ValueError` - If the provided method name does not start with 'step'.
- `Exception` - If the step method call fails. The error message includes a list of available step methods.

**Notes**:

This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.
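
The dispatch described above can be sketched with `getattr`. `StepDispatcher` and its two step methods are hypothetical; only the validation and error-reporting behavior follows the description. The sketch raises `RuntimeError` (a subclass of `Exception`) when the call fails.

```python
class StepDispatcher:
    """Hypothetical class with two step methods, to demonstrate name-based dispatch."""

    def step(self, x: float, dt: float) -> float:
        return x + dt

    def step_noisy(self, x: float, dt: float, scale: float = 0.0) -> float:
        return x + dt + scale

    def general_step(self, method_name: str, kwargs: dict):
        # Only methods whose names start with "step" may be dispatched.
        if not method_name.startswith("step"):
            raise ValueError(f"Method name '{method_name}' must start with 'step'.")
        try:
            return getattr(self, method_name)(**kwargs)
        except Exception as err:
            # List the step methods that do exist so typos are easy to spot.
            available = [
                name for name in dir(self)
                if name.startswith("step") and callable(getattr(self, name))
            ]
            raise RuntimeError(
                f"Calling '{method_name}' failed: {err}. Available step methods: {available}"
            ) from err


dispatcher = StepDispatcher()
print(dispatcher.general_step("step_noisy", {"x": 1.0, "dt": 0.5, "scale": 0.25}))  # 1.75
```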
|
|
<a id="mocointerpolantsbase_interpolantInterpolantsample_prior"></a>

#### sample\_prior

```python
def sample_prior(*args, **kwargs) -> Tensor
```

Samples from the prior distribution.

This method generates a sample from the prior distribution specified by the `prior_distribution` attribute.

**Returns**:

- `Tensor` - The generated sample from the prior distribution.

<a id="mocointerpolantsbase_interpolantInterpolantsample_time"></a>

#### sample\_time

```python
def sample_time(*args, **kwargs) -> Tensor
```

Samples time steps from the time distribution passed to the constructor as `time_distribution`.
|
|
<a id="mocointerpolantsbase_interpolantInterpolantto_device"></a>

#### to\_device

```python
def to_device(device: str)
```

Moves all internal tensors to the specified device and updates the `self.device` attribute.

**Arguments**:

- `device` _str_ - The device to move the tensors to (e.g. "cpu", "cuda:0").

**Notes**:

This method transfers the internal state of the interpolant to a different device.
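
One way to implement this is to walk the instance's attributes and move every tensor. The sketch assumes all tensor state lives in plain instance attributes; `ToyState` and its attributes are hypothetical.

```python
import torch


class ToyState:
    """Hypothetical object holding tensor and non-tensor attributes."""

    def __init__(self):
        self.device = "cpu"
        self.alpha_bar = torch.linspace(1.0, 0.01, steps=5)  # e.g. a noise schedule
        self.name = "toy"

    def to_device(self, device: str) -> None:
        self.device = device
        # Snapshot the attributes first, then move each tensor to the new device.
        for attr, value in list(vars(self).items()):
            if isinstance(value, torch.Tensor):
                setattr(self, attr, value.to(device))


state = ToyState()
target = "cuda:0" if torch.cuda.is_available() else "cpu"
state.to_device(target)
print(state.device, state.alpha_bar.device)
```

Non-tensor attributes such as `name` are left untouched, so the method is safe to call repeatedly.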
|
|
<a id="mocointerpolantsbase_interpolantInterpolantclean_mask_center"></a>

#### clean\_mask\_center

```python
def clean_mask_center(data: Tensor,
                      mask: Optional[Tensor] = None,
                      center: Bool = False) -> Tensor
```

Returns a clean tensor that has been masked and/or centered, depending on the arguments.

**Arguments**:

- `data` - The input data with shape (..., nodes, features).
- `mask` - An optional mask with shape (..., nodes) to apply to the data. If provided, it is also used when calculating the center of mass (CoM). Defaults to None.
- `center` - A boolean indicating whether to center the data around the calculated CoM. Defaults to False.

**Returns**:

The data with shape (..., nodes, features), masked if `mask` is provided and centered around the CoM if `center` is True.
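
A minimal sketch of the masking and centering described above. Assumptions beyond the description: masked-out nodes are zeroed by multiplying with the mask, only unmasked nodes contribute to the CoM, and padded nodes are zeroed again after centering. The sketch uses a plain `bool` for `center`.

```python
from typing import Optional

import torch
from torch import Tensor


def clean_mask_center(data: Tensor, mask: Optional[Tensor] = None, center: bool = False) -> Tensor:
    """Sketch: zero masked-out nodes, then optionally subtract the center of mass (CoM)."""
    if mask is not None:
        data = data * mask.unsqueeze(-1)  # (..., nodes, 1) broadcasts over features
    if not center:
        return data
    if mask is None:
        n_nodes = torch.full(data.shape[:-2], data.shape[-2], dtype=data.dtype, device=data.device)
    else:
        n_nodes = mask.sum(dim=-1).clamp(min=1).to(data.dtype)  # real-node count; avoids division by 0
    com = data.sum(dim=-2) / n_nodes.unsqueeze(-1)  # (..., features)
    centered = data - com.unsqueeze(-2)
    # Re-zero padded nodes so that centering does not move them.
    return centered if mask is None else centered * mask.unsqueeze(-1)


x = torch.tensor([[[0.0, 0.0], [2.0, 0.0], [9.0, 9.0]]])  # (batch=1, nodes=3, features=2)
m = torch.tensor([[True, True, False]])                  # third node is padding
out = clean_mask_center(x, m, center=True)
expected = torch.tensor([[[-1.0, 0.0], [1.0, 0.0], [0.0, 0.0]]])
print(torch.allclose(out, expected))  # True
```

Clamping the node count to at least 1 keeps a fully masked sample from producing NaNs.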
|
|
<a id="mocotesting"></a>

# bionemo.moco.testing

<a id="mocotestingparallel_test_utils"></a>

# bionemo.moco.testing.parallel\_test\_utils

<a id="mocotestingparallel_test_utilsparallel_context"></a>

#### parallel\_context

```python
@contextmanager
def parallel_context(rank: int = 0, world_size: int = 1)
```

Context manager for torch distributed testing.

Sets up and cleans up the distributed environment, including the device mesh.

**Arguments**:

- `rank` _int_ - The rank of the process. Defaults to 0.
- `world_size` _int_ - The world size of the distributed environment. Defaults to 1.

**Yields**:

None
|
|
<a id="mocotestingparallel_test_utilsclean_up_distributed"></a>

#### clean\_up\_distributed

```python
def clean_up_distributed() -> None
```

Cleans up the distributed environment.

Destroys the process group and empties the CUDA cache.

**Arguments**:

None

**Returns**:

None
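
A single-process sketch of the setup and teardown described in this module, using the CPU `gloo` backend so it runs without GPUs. `toy_parallel_context` and `toy_clean_up_distributed` are hypothetical stand-ins: they omit the device mesh, and the rendezvous address and port are arbitrary assumptions.

```python
import os
from contextlib import contextmanager

import torch
import torch.distributed as dist


def toy_clean_up_distributed() -> None:
    """Destroy the default process group (if any) and empty the CUDA cache (if CUDA is present)."""
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@contextmanager
def toy_parallel_context(rank: int = 0, world_size: int = 1):
    """Hypothetical single-node setup/teardown; the documented helper also sets up a device mesh."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29517")  # arbitrary port (assumption)
    # gloo runs on CPU, so this sketch works without GPUs.
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    try:
        yield
    finally:
        # Clean up even if the test body raises.
        toy_clean_up_distributed()


with toy_parallel_context():
    observed = (dist.get_rank(), dist.get_world_size())
print(observed, dist.is_initialized())  # (0, 1) False
```

Running cleanup in `finally` means a failing test cannot leave a process group behind to break the next one.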