# Table of Contents

* [moco](#moco)
* [bionemo.moco.distributions](#mocodistributions)
* [bionemo.moco.distributions.prior.distribution](#mocodistributionspriordistribution)
* [bionemo.moco.distributions.prior.discrete.uniform](#mocodistributionspriordiscreteuniform)
* [bionemo.moco.distributions.prior.discrete.custom](#mocodistributionspriordiscretecustom)
* [bionemo.moco.distributions.prior.discrete](#mocodistributionspriordiscrete)
* [bionemo.moco.distributions.prior.discrete.mask](#mocodistributionspriordiscretemask)
* [bionemo.moco.distributions.prior.continuous.harmonic](#mocodistributionspriorcontinuousharmonic)
* [bionemo.moco.distributions.prior.continuous](#mocodistributionspriorcontinuous)
* [bionemo.moco.distributions.prior.continuous.gaussian](#mocodistributionspriorcontinuousgaussian)
* [bionemo.moco.distributions.prior.continuous.utils](#mocodistributionspriorcontinuousutils)
* [bionemo.moco.distributions.prior](#mocodistributionsprior)
* [bionemo.moco.distributions.time.distribution](#mocodistributionstimedistribution)
* [bionemo.moco.distributions.time.uniform](#mocodistributionstimeuniform)
* [bionemo.moco.distributions.time.logit\_normal](#mocodistributionstimelogit_normal)
* [bionemo.moco.distributions.time](#mocodistributionstime)
* [bionemo.moco.distributions.time.beta](#mocodistributionstimebeta)
* [bionemo.moco.distributions.time.utils](#mocodistributionstimeutils)
* [bionemo.moco.schedules.noise.continuous\_snr\_transforms](#mocoschedulesnoisecontinuous_snr_transforms)
* [bionemo.moco.schedules.noise.discrete\_noise\_schedules](#mocoschedulesnoisediscrete_noise_schedules)
* [bionemo.moco.schedules.noise](#mocoschedulesnoise)
* [bionemo.moco.schedules.noise.continuous\_noise\_transforms](#mocoschedulesnoisecontinuous_noise_transforms)
* [bionemo.moco.schedules](#mocoschedules)
* [bionemo.moco.schedules.utils](#mocoschedulesutils)
* [bionemo.moco.schedules.inference\_time\_schedules](#mocoschedulesinference_time_schedules)
* [bionemo.moco.interpolants.continuous\_time.discrete](#mocointerpolantscontinuous_timediscrete)
* [bionemo.moco.interpolants.continuous\_time.discrete.mdlm](#mocointerpolantscontinuous_timediscretemdlm)
* [bionemo.moco.interpolants.continuous\_time.discrete.discrete\_flow\_matching](#mocointerpolantscontinuous_timediscretediscrete_flow_matching)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.ot\_sampler](#mocointerpolantscontinuous_timecontinuousdata_augmentationot_sampler)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.equivariant\_ot\_sampler](#mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_sampler)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.kabsch\_augmentation](#mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentation)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation](#mocointerpolantscontinuous_timecontinuousdata_augmentation)
* [bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.augmentation\_types](#mocointerpolantscontinuous_timecontinuousdata_augmentationaugmentation_types)
* [bionemo.moco.interpolants.continuous\_time.continuous](#mocointerpolantscontinuous_timecontinuous)
* [bionemo.moco.interpolants.continuous\_time.continuous.vdm](#mocointerpolantscontinuous_timecontinuousvdm)
* [bionemo.moco.interpolants.continuous\_time.continuous.continuous\_flow\_matching](#mocointerpolantscontinuous_timecontinuouscontinuous_flow_matching)
* [bionemo.moco.interpolants.continuous\_time](#mocointerpolantscontinuous_time)
* [bionemo.moco.interpolants](#mocointerpolants)
* [bionemo.moco.interpolants.batch\_augmentation](#mocointerpolantsbatch_augmentation)
* [bionemo.moco.interpolants.discrete\_time.discrete.d3pm](#mocointerpolantsdiscrete_timediscreted3pm)
* [bionemo.moco.interpolants.discrete\_time.discrete](#mocointerpolantsdiscrete_timediscrete)
* [bionemo.moco.interpolants.discrete\_time.continuous.ddpm](#mocointerpolantsdiscrete_timecontinuousddpm)
* [bionemo.moco.interpolants.discrete\_time.continuous](#mocointerpolantsdiscrete_timecontinuous)
* [bionemo.moco.interpolants.discrete\_time](#mocointerpolantsdiscrete_time)
* [bionemo.moco.interpolants.discrete\_time.utils](#mocointerpolantsdiscrete_timeutils)
* [bionemo.moco.interpolants.base\_interpolant](#mocointerpolantsbase_interpolant)
* [bionemo.moco.testing](#mocotesting)
* [bionemo.moco.testing.parallel\_test\_utils](#mocotestingparallel_test_utils)
<a id="moco"></a>
53
54
# moco
55
56
<a id="mocodistributions"></a>
57
58
# bionemo.moco.distributions
59
60
<a id="mocodistributionspriordistribution"></a>
61
62
# bionemo.moco.distributions.prior.distribution
63
64
<a id="mocodistributionspriordistributionPriorDistribution"></a>
65
66
## PriorDistribution Objects
67
68
```python
69
class PriorDistribution(ABC)
70
```
71
72
An abstract base class representing a prior distribution.
73
74
<a id="mocodistributionspriordistributionPriorDistributionsample"></a>
75
76
#### sample
77
78
```python
79
@abstractmethod
80
def sample(shape: Tuple,
81
           mask: Optional[Tensor] = None,
82
           device: Union[str, torch.device] = "cpu") -> Tensor
83
```
84
85
Generates a specified number of samples from the time distribution.
86
87
**Arguments**:
88
89
- `shape` _Tuple_ - The shape of the samples to generate.
90
- `mask` _Optional[Tensor], optional_ - A tensor indicating which samples should be masked. Defaults to None.
91
- `device` _str, optional_ - The device on which to generate the samples. Defaults to "cpu".
92
93
94
**Returns**:
95
96
- `Float` - A tensor of samples.
97
98
<a id="mocodistributionspriordistributionDiscretePriorDistribution"></a>
99
100
## DiscretePriorDistribution Objects
101
102
```python
103
class DiscretePriorDistribution(PriorDistribution)
104
```
105
106
An abstract base class representing a discrete prior distribution.
107
108
<a id="mocodistributionspriordistributionDiscretePriorDistribution__init__"></a>
109
110
#### \_\_init\_\_
111
112
```python
113
def __init__(num_classes: int, prior_dist: Tensor)
114
```
115
116
Initializes a DiscretePriorDistribution instance.
117
118
**Arguments**:
119
120
- `num_classes` _int_ - The number of classes in the discrete distribution.
121
- `prior_dist` _Tensor_ - The prior distribution over the classes.
122
123
124
**Returns**:
125
126
  None
127
128
<a id="mocodistributionspriordistributionDiscretePriorDistributionget_num_classes"></a>
129
130
#### get\_num\_classes
131
132
```python
133
def get_num_classes() -> int
134
```
135
136
Getter for num_classes.
137
138
<a id="mocodistributionspriordistributionDiscretePriorDistributionget_prior_dist"></a>
139
140
#### get\_prior\_dist
141
142
```python
143
def get_prior_dist() -> Tensor
144
```
145
146
Getter for prior_dist.
147
148
<a id="mocodistributionspriordiscreteuniform"></a>
149
150
# bionemo.moco.distributions.prior.discrete.uniform
151
152
<a id="mocodistributionspriordiscreteuniformDiscreteUniformPrior"></a>
153
154
## DiscreteUniformPrior Objects
155
156
```python
157
class DiscreteUniformPrior(DiscretePriorDistribution)
158
```
159
160
A subclass representing a discrete uniform prior distribution.
161
162
<a id="mocodistributionspriordiscreteuniformDiscreteUniformPrior__init__"></a>
163
164
#### \_\_init\_\_
165
166
```python
167
def __init__(num_classes: int = 10) -> None
168
```
169
170
Initializes a discrete uniform prior distribution.
171
172
**Arguments**:
173
174
- `num_classes` _int_ - The number of classes in the discrete uniform distribution. Defaults to 10.
175
176
<a id="mocodistributionspriordiscreteuniformDiscreteUniformPriorsample"></a>
177
178
#### sample
179
180
```python
181
def sample(shape: Tuple,
182
           mask: Optional[Tensor] = None,
183
           device: Union[str, torch.device] = "cpu",
184
           rng_generator: Optional[torch.Generator] = None) -> Tensor
185
```
186
187
Generates a specified number of samples.
188
189
**Arguments**:
190
191
- `shape` _Tuple_ - The shape of the samples to generate.
192
- `device` _str_ - cpu or gpu.
193
- `mask` _Optional[Tensor]_ - An optional mask to apply to the samples. Defaults to None.
194
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
195
196
197
**Returns**:
198
199
- `Float` - A tensor of samples.
200
201
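
For instance, a minimal usage sketch (the import path follows the module heading above; the shapes are illustrative):

```python
import torch

from bionemo.moco.distributions.prior.discrete.uniform import DiscreteUniformPrior

# Uniform prior over 4 classes.
prior = DiscreteUniformPrior(num_classes=4)

# Draw a (batch, sequence) tensor of class indices, reproducibly.
generator = torch.Generator(device="cpu").manual_seed(42)
samples = prior.sample((8, 16), rng_generator=generator)
assert samples.shape == (8, 16)
```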
<a id="mocodistributionspriordiscretecustom"></a>
202
203
# bionemo.moco.distributions.prior.discrete.custom
204
205
<a id="mocodistributionspriordiscretecustomDiscreteCustomPrior"></a>
206
207
## DiscreteCustomPrior Objects
208
209
```python
210
class DiscreteCustomPrior(DiscretePriorDistribution)
211
```
212
213
A subclass representing a discrete custom prior distribution.
214
215
This class allows for the creation of a prior distribution with a custom
216
probability mass function defined by the `prior_dist` tensor. For example if my data has 4 classes and I want [.3, .2, .4, .1] as the probabilities of the 4 classes.
217
218
<a id="mocodistributionspriordiscretecustomDiscreteCustomPrior__init__"></a>
219
220
#### \_\_init\_\_
221
222
```python
223
def __init__(prior_dist: Tensor, num_classes: int = 10) -> None
224
```
225
226
Initializes a DiscreteCustomPrior distribution.
227
228
**Arguments**:
229
230
- `prior_dist` - A tensor representing the probability mass function of the prior distribution.
231
- `num_classes` - The number of classes in the prior distribution. Defaults to 10.
232
233
234
**Notes**:
235
236
  The `prior_dist` tensor should have a sum close to 1.0, as it represents a probability mass function.
237
238
<a id="mocodistributionspriordiscretecustomDiscreteCustomPriorsample"></a>
239
240
#### sample
241
242
```python
243
def sample(shape: Tuple,
244
           mask: Optional[Tensor] = None,
245
           device: Union[str, torch.device] = "cpu",
246
           rng_generator: Optional[torch.Generator] = None) -> Tensor
247
```
248
249
Samples from the discrete custom prior distribution.
250
251
**Arguments**:
252
253
- `shape` - A tuple specifying the shape of the samples to generate.
254
- `mask` - An optional tensor mask to apply to the samples, broadcastable to the sample shape. Defaults to None.
255
- `device` - The device on which to generate the samples, specified as a string or a :class:`torch.device`. Defaults to "cpu".
256
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
257
258
259
**Returns**:
260
261
  A tensor of samples drawn from the prior distribution.
262
263
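
Continuing the 4-class example above, a sketch (the import path follows the module heading):

```python
import torch

from bionemo.moco.distributions.prior.discrete.custom import DiscreteCustomPrior

# Probability mass function over 4 classes; should sum to ~1.0.
prior = DiscreteCustomPrior(prior_dist=torch.tensor([0.3, 0.2, 0.4, 0.1]), num_classes=4)

samples = prior.sample((8, 16))  # class indices drawn with the given probabilities
```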
<a id="mocodistributionspriordiscrete"></a>
264
265
# bionemo.moco.distributions.prior.discrete
266
267
<a id="mocodistributionspriordiscretemask"></a>
268
269
# bionemo.moco.distributions.prior.discrete.mask
270
271
<a id="mocodistributionspriordiscretemaskDiscreteMaskedPrior"></a>
272
273
## DiscreteMaskedPrior Objects
274
275
```python
276
class DiscreteMaskedPrior(DiscretePriorDistribution)
277
```
278
279
A subclass representing a Discrete Masked prior distribution.
280
281
<a id="mocodistributionspriordiscretemaskDiscreteMaskedPrior__init__"></a>
282
283
#### \_\_init\_\_
284
285
```python
286
def __init__(num_classes: int = 10,
287
             mask_dim: Optional[int] = None,
288
             inclusive: bool = True) -> None
289
```
290
291
Discrete Masked prior distribution.
292
293
Theres 3 ways I can think of defining the problem that are hard to mesh together.
294
295
1. [..., M, ....] inclusive anywhere --> exisiting LLM tokenizer where the mask has a specific location not at the end
296
2. [......, M] inclusive on end --> mask_dim = None with inclusive set to True default stick on the end
297
3. [.....] + [M] exclusive --> the number of classes representes the number of data classes and one wishes to add a separate MASK dimension.
298
- Note the pad_sample function is provided to help add this extra external dimension.
299
300
**Arguments**:
301
302
- `num_classes` _int_ - The number of classes in the distribution. Defaults to 10.
303
- `mask_dim` _int_ - The index for the mask token. Defaults to num_classes - 1 if inclusive or num_classes if exclusive.
304
- `inclusive` _bool_ - Whether the mask is included in the specified number of classes.
305
  If True, the mask is considered as one of the classes.
306
  If False, the mask is considered as an additional class. Defaults to True.
307
308
<a id="mocodistributionspriordiscretemaskDiscreteMaskedPriorsample"></a>
309
310
#### sample
311
312
```python
313
def sample(shape: Tuple,
314
           mask: Optional[Tensor] = None,
315
           device: Union[str, torch.device] = "cpu",
316
           rng_generator: Optional[torch.Generator] = None) -> Tensor
317
```
318
319
Generates a specified number of samples.
320
321
**Arguments**:
322
323
- `shape` _Tuple_ - The shape of the samples to generate.
324
- `device` _str_ - cpu or gpu.
325
- `mask` _Optional[Tensor]_ - An optional mask to apply to the samples. Defaults to None.
326
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
327
328
329
**Returns**:
330
331
- `Float` - A tensor of samples.
332
333
<a id="mocodistributionspriordiscretemaskDiscreteMaskedPrioris_masked"></a>
334
335
#### is\_masked
336
337
```python
338
def is_masked(sample: Tensor) -> Tensor
339
```
340
341
Creates a mask for whether a state is masked.
342
343
**Arguments**:
344
345
- `sample` _Tensor_ - The sample to check.
346
347
348
**Returns**:
349
350
- `Tensor` - A float tensor indicating whether the sample is masked.
351
352
<a id="mocodistributionspriordiscretemaskDiscreteMaskedPriorpad_sample"></a>
353
354
#### pad\_sample
355
356
```python
357
def pad_sample(sample: Tensor) -> Tensor
358
```
359
360
Pads the input sample with zeros along the last dimension.
361
362
**Arguments**:
363
364
- `sample` _Tensor_ - The input sample to be padded.
365
366
367
**Returns**:
368
369
- `Tensor` - The padded sample.
370
371
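
A short sketch of the inclusive case, assuming (as the class name implies) that the prior places its mass on the mask token:

```python
from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior

# 10 classes with the mask token included as index 9 (num_classes - 1).
prior = DiscreteMaskedPrior(num_classes=10)

samples = prior.sample((4, 32))        # positions start out as mask tokens
mask_flags = prior.is_masked(samples)  # float tensor with 1.0 where masked
```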
<a id="mocodistributionspriorcontinuousharmonic"></a>
372
373
# bionemo.moco.distributions.prior.continuous.harmonic
374
375
<a id="mocodistributionspriorcontinuousharmonicLinearHarmonicPrior"></a>
376
377
## LinearHarmonicPrior Objects
378
379
```python
380
class LinearHarmonicPrior(PriorDistribution)
381
```
382
383
A subclass representing a Linear Harmonic prior distribution from Jing et al. https://arxiv.org/abs/2304.02198.
384
385
<a id="mocodistributionspriorcontinuousharmonicLinearHarmonicPrior__init__"></a>
386
387
#### \_\_init\_\_
388
389
```python
390
def __init__(length: Optional[int] = None,
391
             distance: Float = 3.8,
392
             center: Bool = False,
393
             rng_generator: Optional[torch.Generator] = None,
394
             device: Union[str, torch.device] = "cpu") -> None
395
```
396
397
Linear Harmonic prior distribution.
398
399
**Arguments**:
400
401
- `length` _Optional[int]_ - The number of points in a batch.
402
- `distance` _Float_ - RMS distance between adjacent points in the line graph.
403
- `center` _bool_ - Whether to center the samples around the mean. Defaults to False.
404
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
405
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
406
407
<a id="mocodistributionspriorcontinuousharmonicLinearHarmonicPriorsample"></a>
408
409
#### sample
410
411
```python
412
def sample(shape: Tuple,
413
           mask: Optional[Tensor] = None,
414
           device: Union[str, torch.device] = "cpu",
415
           rng_generator: Optional[torch.Generator] = None) -> Tensor
416
```
417
418
Generates a specified number of samples from the Harmonic prior distribution.
419
420
**Arguments**:
421
422
- `shape` _Tuple_ - The shape of the samples to generate.
423
- `device` _str_ - cpu or gpu.
424
- `mask` _Optional[Tensor]_ - An optional mask to apply to the samples. Defaults to None.
425
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
426
427
428
**Returns**:
429
430
- `Float` - A tensor of samples.
431
432
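
A usage sketch for molecular coordinates; the (batch, length, 3) shape convention is an assumption for illustration:

```python
from bionemo.moco.distributions.prior.continuous.harmonic import LinearHarmonicPrior

# A chain of 50 points with ~3.8 RMS distance between neighbors, centered at the origin.
prior = LinearHarmonicPrior(length=50, distance=3.8, center=True)

coords = prior.sample((2, 50, 3))  # e.g. a batch of 2 chains of 3D coordinates
```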
<a id="mocodistributionspriorcontinuous"></a>
433
434
# bionemo.moco.distributions.prior.continuous
435
436
<a id="mocodistributionspriorcontinuousgaussian"></a>
437
438
# bionemo.moco.distributions.prior.continuous.gaussian
439
440
<a id="mocodistributionspriorcontinuousgaussianGaussianPrior"></a>
441
442
## GaussianPrior Objects
443
444
```python
445
class GaussianPrior(PriorDistribution)
446
```
447
448
A subclass representing a Gaussian prior distribution.
449
450
<a id="mocodistributionspriorcontinuousgaussianGaussianPrior__init__"></a>
451
452
#### \_\_init\_\_
453
454
```python
455
def __init__(mean: Float = 0.0,
456
             std: Float = 1.0,
457
             center: Bool = False,
458
             rng_generator: Optional[torch.Generator] = None) -> None
459
```
460
461
Gaussian prior distribution.
462
463
**Arguments**:
464
465
- `mean` _Float_ - The mean of the Gaussian distribution. Defaults to 0.0.
466
- `std` _Float_ - The standard deviation of the Gaussian distribution. Defaults to 1.0.
467
- `center` _bool_ - Whether to center the samples around the mean. Defaults to False.
468
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
469
470
<a id="mocodistributionspriorcontinuousgaussianGaussianPriorsample"></a>
471
472
#### sample
473
474
```python
475
def sample(shape: Tuple,
476
           mask: Optional[Tensor] = None,
477
           device: Union[str, torch.device] = "cpu",
478
           rng_generator: Optional[torch.Generator] = None) -> Tensor
479
```
480
481
Generates a specified number of samples from the Gaussian prior distribution.
482
483
**Arguments**:
484
485
- `shape` _Tuple_ - The shape of the samples to generate.
486
- `device` _str_ - cpu or gpu.
487
- `mask` _Optional[Tensor]_ - An optional mask to apply to the samples. Defaults to None.
488
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
489
490
491
**Returns**:
492
493
- `Float` - A tensor of samples.
494
495
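
A minimal sketch (import path from the module heading; the shape is illustrative):

```python
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior

# Standard Gaussian noise, centered around the mean per sample.
prior = GaussianPrior(mean=0.0, std=1.0, center=True)

noise = prior.sample((2, 50, 3))  # same shape convention as the data it will replace
```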
<a id="mocodistributionspriorcontinuousutils"></a>
496
497
# bionemo.moco.distributions.prior.continuous.utils
498
499
<a id="mocodistributionspriorcontinuousutilsremove_center_of_mass"></a>
500
501
#### remove\_center\_of\_mass
502
503
```python
504
def remove_center_of_mass(data: Tensor,
505
                          mask: Optional[Tensor] = None) -> Tensor
506
```
507
508
Calculates the center of mass (CoM) of the given data.
509
510
**Arguments**:
511
512
- `data` - The input data with shape (..., nodes, features).
513
- `mask` - An optional binary mask to apply to the data with shape (..., nodes) to mask out interaction from CoM calculation. Defaults to None.
514
515
516
**Returns**:
517
518
  The CoM of the data with shape (..., 1, features).
519
520
<a id="mocodistributionsprior"></a>
521
522
# bionemo.moco.distributions.prior
523
524
<a id="mocodistributionstimedistribution"></a>
525
526
# bionemo.moco.distributions.time.distribution
527
528
<a id="mocodistributionstimedistributionTimeDistribution"></a>
529
530
## TimeDistribution Objects
531
532
```python
533
class TimeDistribution(ABC)
534
```
535
536
An abstract base class representing a time distribution.
537
538
**Arguments**:
539
540
- `discrete_time` _Bool_ - Whether the time is discrete.
541
- `nsteps` _Optional[int]_ - Number of nsteps for discretization.
542
- `min_t` _Optional[Float]_ - Min continuous time.
543
- `max_t` _Optional[Float]_ - Max continuous time.
544
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
545
546
<a id="mocodistributionstimedistributionTimeDistribution__init__"></a>
547
548
#### \_\_init\_\_
549
550
```python
551
def __init__(discrete_time: Bool = False,
552
             nsteps: Optional[int] = None,
553
             min_t: Optional[Float] = None,
554
             max_t: Optional[Float] = None,
555
             rng_generator: Optional[torch.Generator] = None)
556
```
557
558
Initializes a TimeDistribution object.
559
560
<a id="mocodistributionstimedistributionTimeDistributionsample"></a>
561
562
#### sample
563
564
```python
565
@abstractmethod
566
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
567
           device: Union[str, torch.device] = "cpu",
568
           rng_generator: Optional[torch.Generator] = None) -> Float
569
```
570
571
Generates a specified number of samples from the time distribution.
572
573
**Arguments**:
574
575
- `n_samples` _int_ - The number of samples to generate.
576
- `device` _str_ - cpu or gpu.
577
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
578
579
580
**Returns**:
581
582
- `Float` - A list or array of samples.
583
584
<a id="mocodistributionstimedistributionMixTimeDistribution"></a>
585
586
## MixTimeDistribution Objects
587
588
```python
589
class MixTimeDistribution()
590
```
591
592
An abstract base class representing a mixed time distribution.
593
594
uniform_dist = UniformTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False)
595
beta_dist = BetaTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False, p1=2.0, p2=1.0)
596
mix_dist = MixTimeDistribution(uniform_dist, beta_dist, mix_fraction=0.5)
597
598
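
The usage pattern from the original docstring, expanded into a runnable sketch (import paths follow the module headings in this document):

```python
from bionemo.moco.distributions.time.beta import BetaTimeDistribution
from bionemo.moco.distributions.time.distribution import MixTimeDistribution
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution

uniform_dist = UniformTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False)
beta_dist = BetaTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False, p1=2.0, p2=1.0)
mix_dist = MixTimeDistribution(uniform_dist, beta_dist, mix_fraction=0.5)

t = mix_dist.sample(128)  # on average, half the samples come from each distribution
```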
<a id="mocodistributionstimedistributionMixTimeDistribution__init__"></a>
599
600
#### \_\_init\_\_
601
602
```python
603
def __init__(dist1: TimeDistribution, dist2: TimeDistribution,
604
             mix_fraction: Float)
605
```
606
607
Initializes a MixTimeDistribution object.
608
609
**Arguments**:
610
611
- `dist1` _TimeDistribution_ - The first time distribution.
612
- `dist2` _TimeDistribution_ - The second time distribution.
613
- `mix_fraction` _Float_ - The fraction of samples to draw from dist1. Must be between 0 and 1.
614
615
<a id="mocodistributionstimedistributionMixTimeDistributionsample"></a>
616
617
#### sample
618
619
```python
620
def sample(n_samples: int,
621
           device: Union[str, torch.device] = "cpu",
622
           rng_generator: Optional[torch.Generator] = None) -> Float
623
```
624
625
Generates a specified number of samples from the mixed time distribution.
626
627
**Arguments**:
628
629
- `n_samples` _int_ - The number of samples to generate.
630
- `device` _str_ - cpu or gpu.
631
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
632
633
634
**Returns**:
635
636
- `Float` - A list or array of samples.
637
638
<a id="mocodistributionstimeuniform"></a>
639
640
# bionemo.moco.distributions.time.uniform
641
642
<a id="mocodistributionstimeuniformUniformTimeDistribution"></a>
643
644
## UniformTimeDistribution Objects
645
646
```python
647
class UniformTimeDistribution(TimeDistribution)
648
```
649
650
A class representing a uniform time distribution.
651
652
<a id="mocodistributionstimeuniformUniformTimeDistribution__init__"></a>
653
654
#### \_\_init\_\_
655
656
```python
657
def __init__(min_t: Float = 0.0,
658
             max_t: Float = 1.0,
659
             discrete_time: Bool = False,
660
             nsteps: Optional[int] = None,
661
             rng_generator: Optional[torch.Generator] = None)
662
```
663
664
Initializes a UniformTimeDistribution object.
665
666
**Arguments**:
667
668
- `min_t` _Float_ - The minimum time value.
669
- `max_t` _Float_ - The maximum time value.
670
- `discrete_time` _Bool_ - Whether the time is discrete.
671
- `nsteps` _Optional[int]_ - Number of nsteps for discretization.
672
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
673
674
<a id="mocodistributionstimeuniformUniformTimeDistributionsample"></a>
675
676
#### sample
677
678
```python
679
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
680
           device: Union[str, torch.device] = "cpu",
681
           rng_generator: Optional[torch.Generator] = None)
682
```
683
684
Generates a specified number of samples from the uniform time distribution.
685
686
**Arguments**:
687
688
- `n_samples` _int_ - The number of samples to generate.
689
- `device` _str_ - cpu or gpu.
690
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
691
692
693
**Returns**:
694
695
  A tensor of samples.
696
697
<a id="mocodistributionstimeuniformSymmetricUniformTimeDistribution"></a>
698
699
## SymmetricUniformTimeDistribution Objects
700
701
```python
702
class SymmetricUniformTimeDistribution(TimeDistribution)
703
```
704
705
A class representing a uniform time distribution.
706
707
<a id="mocodistributionstimeuniformSymmetricUniformTimeDistribution__init__"></a>
708
709
#### \_\_init\_\_
710
711
```python
712
def __init__(min_t: Float = 0.0,
713
             max_t: Float = 1.0,
714
             discrete_time: Bool = False,
715
             nsteps: Optional[int] = None,
716
             rng_generator: Optional[torch.Generator] = None)
717
```
718
719
Initializes a UniformTimeDistribution object.
720
721
**Arguments**:
722
723
- `min_t` _Float_ - The minimum time value.
724
- `max_t` _Float_ - The maximum time value.
725
- `discrete_time` _Bool_ - Whether the time is discrete.
726
- `nsteps` _Optional[int]_ - Number of nsteps for discretization.
727
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
728
729
<a id="mocodistributionstimeuniformSymmetricUniformTimeDistributionsample"></a>
730
731
#### sample
732
733
```python
734
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
735
           device: Union[str, torch.device] = "cpu",
736
           rng_generator: Optional[torch.Generator] = None)
737
```
738
739
Generates a specified number of samples from the uniform time distribution.
740
741
**Arguments**:
742
743
- `n_samples` _int_ - The number of samples to generate.
744
- `device` _str_ - cpu or gpu.
745
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
746
747
748
**Returns**:
749
750
  A tensor of samples.
751
752
<a id="mocodistributionstimelogit_normal"></a>
753
754
# bionemo.moco.distributions.time.logit\_normal
755
756
<a id="mocodistributionstimelogit_normalLogitNormalTimeDistribution"></a>
757
758
## LogitNormalTimeDistribution Objects
759
760
```python
761
class LogitNormalTimeDistribution(TimeDistribution)
762
```
763
764
A class representing a logit normal time distribution.
765
766
<a id="mocodistributionstimelogit_normalLogitNormalTimeDistribution__init__"></a>
767
768
#### \_\_init\_\_
769
770
```python
771
def __init__(p1: Float = 0.0,
772
             p2: Float = 1.0,
773
             min_t: Float = 0.0,
774
             max_t: Float = 1.0,
775
             discrete_time: Bool = False,
776
             nsteps: Optional[int] = None,
777
             rng_generator: Optional[torch.Generator] = None)
778
```
779
780
Initializes a BetaTimeDistribution object.
781
782
**Arguments**:
783
784
- `p1` _Float_ - The first shape parameter of the logit normal distribution i.e. the mean.
785
- `p2` _Float_ - The second shape parameter of the logit normal distribution i.e. the std.
786
- `min_t` _Float_ - The minimum time value.
787
- `max_t` _Float_ - The maximum time value.
788
- `discrete_time` _Bool_ - Whether the time is discrete.
789
- `nsteps` _Optional[int]_ - Number of nsteps for discretization.
790
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
791
792
<a id="mocodistributionstimelogit_normalLogitNormalTimeDistributionsample"></a>
793
794
#### sample
795
796
```python
797
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
798
           device: Union[str, torch.device] = "cpu",
799
           rng_generator: Optional[torch.Generator] = None)
800
```
801
802
Generates a specified number of samples from the uniform time distribution.
803
804
**Arguments**:
805
806
- `n_samples` _int_ - The number of samples to generate.
807
- `device` _str_ - cpu or gpu.
808
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
809
810
811
**Returns**:
812
813
  A tensor of samples.
814
815
<a id="mocodistributionstime"></a>
816
817
# bionemo.moco.distributions.time
818
819
<a id="mocodistributionstimebeta"></a>
820
821
# bionemo.moco.distributions.time.beta
822
823
<a id="mocodistributionstimebetaBetaTimeDistribution"></a>
824
825
## BetaTimeDistribution Objects
826
827
```python
828
class BetaTimeDistribution(TimeDistribution)
829
```
830
831
A class representing a beta time distribution.
832
833
<a id="mocodistributionstimebetaBetaTimeDistribution__init__"></a>
834
835
#### \_\_init\_\_
836
837
```python
838
def __init__(p1: Float = 2.0,
839
             p2: Float = 1.0,
840
             min_t: Float = 0.0,
841
             max_t: Float = 1.0,
842
             discrete_time: Bool = False,
843
             nsteps: Optional[int] = None,
844
             rng_generator: Optional[torch.Generator] = None)
845
```
846
847
Initializes a BetaTimeDistribution object.
848
849
**Arguments**:
850
851
- `p1` _Float_ - The first shape parameter of the beta distribution.
852
- `p2` _Float_ - The second shape parameter of the beta distribution.
853
- `min_t` _Float_ - The minimum time value.
854
- `max_t` _Float_ - The maximum time value.
855
- `discrete_time` _Bool_ - Whether the time is discrete.
856
- `nsteps` _Optional[int]_ - Number of nsteps for discretization.
857
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
858
859
<a id="mocodistributionstimebetaBetaTimeDistributionsample"></a>
860
861
#### sample
862
863
```python
864
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
865
           device: Union[str, torch.device] = "cpu",
866
           rng_generator: Optional[torch.Generator] = None)
867
```
868
869
Generates a specified number of samples from the uniform time distribution.
870
871
**Arguments**:
872
873
- `n_samples` _int_ - The number of samples to generate.
874
- `device` _str_ - cpu or gpu.
875
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
876
877
878
**Returns**:
879
880
  A tensor of samples.
881
882
<a id="mocodistributionstimeutils"></a>
883
884
# bionemo.moco.distributions.time.utils
885
886
<a id="mocodistributionstimeutilsfloat_time_to_index"></a>
887
888
#### float\_time\_to\_index
889
890
```python
891
def float_time_to_index(time: torch.Tensor,
892
                        num_time_steps: int) -> torch.Tensor
893
```
894
895
Convert a float time value to a time index.
896
897
**Arguments**:
898
899
- `time` _torch.Tensor_ - A tensor of float time values in the range [0, 1].
900
- `num_time_steps` _int_ - The number of discrete time steps.
901
902
903
**Returns**:
904
905
- `torch.Tensor` - A tensor of time indices corresponding to the input float time values.
906
907
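
For example (a sketch; the exact rounding rule at bin boundaries is implementation-defined):

```python
import torch

from bionemo.moco.distributions.time.utils import float_time_to_index

t = torch.tensor([0.0, 0.25, 0.5, 1.0])
idx = float_time_to_index(t, num_time_steps=1000)  # integer indices in [0, 999]
```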
<a id="mocoschedulesnoisecontinuous_snr_transforms"></a>
908
909
# bionemo.moco.schedules.noise.continuous\_snr\_transforms
910
911
<a id="mocoschedulesnoisecontinuous_snr_transformslog"></a>
912
913
#### log
914
915
```python
916
def log(t, eps=1e-20)
917
```
918
919
Compute the natural logarithm of a tensor, clamping values to avoid numerical instability.
920
921
**Arguments**:
922
923
- `t` _Tensor_ - The input tensor.
924
- `eps` _float, optional_ - The minimum value to clamp the input tensor (default is 1e-20).
925
926
927
**Returns**:
928
929
- `Tensor` - The natural logarithm of the input tensor.
930
931
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransform"></a>
932
933
## ContinuousSNRTransform Objects
934
935
```python
936
class ContinuousSNRTransform(ABC)
937
```
938
939
A base class for continuous SNR schedules.
940
941
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransform__init__"></a>
942
943
#### \_\_init\_\_
944
945
```python
946
def __init__(direction: TimeDirection)
947
```
948
949
Initialize the DiscreteNoiseSchedule.
950
951
**Arguments**:
952
953
- `direction` _TimeDirection_ - required this defines in which direction the scheduler was built
954
955
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_log_snr"></a>
956
957
#### calculate\_log\_snr
958
959
```python
960
def calculate_log_snr(t: Tensor,
961
                      device: Union[str, torch.device] = "cpu",
962
                      synchronize: Optional[TimeDirection] = None) -> Tensor
963
```
964
965
Public wrapper to generate the time schedule as a tensor.
966
967
**Arguments**:
968
969
- `t` _Tensor_ - The input tensor representing the time steps, with values ranging from 0 to 1.
970
- `device` _Optional[str]_ - The device to place the schedule on. Defaults to "cpu".
971
- `synchronize` _optional[TimeDirection]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
972
  this parameter allows to flip the direction to match the specified one. Defaults to None.
973
974
975
**Returns**:
976
977
- `Tensor` - A tensor representing the log signal-to-noise (SNR) ratio for the given time steps.
978
979
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformlog_snr_to_alphas_sigmas"></a>
980
981
#### log\_snr\_to\_alphas\_sigmas
982
983
```python
984
def log_snr_to_alphas_sigmas(log_snr: Tensor) -> Tuple[Tensor, Tensor]
985
```
986
987
Converts log signal-to-noise ratio (SNR) to alpha and sigma values.
988
989
**Arguments**:
990
991
- `log_snr` _Tensor_ - The input log SNR tensor.
992
993
994
**Returns**:
995
996
  tuple[Tensor, Tensor]: A tuple containing the squared root of alpha and sigma values.
997
998
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformderivative"></a>
999
1000
#### derivative
1001
1002
```python
1003
def derivative(t: Tensor, func: Callable) -> Tensor
1004
```
1005
1006
Compute derivative of a function, it supports bached single variable inputs.
1007
1008
**Arguments**:
1009
1010
- `t` _Tensor_ - time variable at which derivatives are taken
1011
- `func` _Callable_ - function for derivative calculation
1012
1013
1014
**Returns**:
1015
1016
- `Tensor` - derivative that is detached from the computational graph
1017
1018
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_general_sde_terms"></a>
1019
1020
#### calculate\_general\_sde\_terms
1021
1022
```python
1023
def calculate_general_sde_terms(t)
1024
```
1025
1026
Compute the general SDE terms for a given time step t.
1027
1028
**Arguments**:
1029
1030
- `t` _Tensor_ - The input tensor representing the time step.
1031
1032
1033
**Returns**:
1034
1035
  tuple[Tensor, Tensor]: A tuple containing the drift term f_t and the diffusion term g_t_2.
1036
1037
1038
**Notes**:
1039
1040
  This method computes the drift and diffusion terms of the general SDE, which can be used to simulate the stochastic process.
1041
  The drift term represents the deterministic part of the process, while the diffusion term represents the stochastic part.
1042
1043
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_beta"></a>
1044
1045
#### calculate\_beta
1046
1047
```python
1048
def calculate_beta(t)
1049
```
1050
1051
Compute the drift coefficient for the OU process of the form $dx = -\frac{1}{2} \beta(t) x dt + sqrt(beta(t)) dw_t$.
1052
1053
beta = d/dt log(alpha**2) = 2 * 1/alpha * d/dt(alpha)
1054
1055
**Arguments**:
1056
1057
- `t` _Union[float, Tensor]_ - t in [0, 1]
1058
1059
1060
**Returns**:
1061
1062
- `Tensor` - beta(t)
1063
1064
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_alpha_log_snr"></a>
1065
1066
#### calculate\_alpha\_log\_snr
1067
1068
```python
1069
def calculate_alpha_log_snr(log_snr: Tensor) -> Tensor
1070
```
1071
1072
Compute alpha values based on the log SNR.
1073
1074
**Arguments**:
1075
1076
- `log_snr` _Tensor_ - The input tensor representing the log signal-to-noise ratio.
1077
1078
1079
**Returns**:
1080
1081
- `Tensor` - A tensor representing the alpha values for the given log SNR.
1082
1083
1084
**Notes**:
1085
1086
  This method computes alpha values as the square root of the sigmoid of the log SNR.
1087
1088
<a id="mocoschedulesnoisecontinuous_snr_transformsContinuousSNRTransformcalculate_alpha_t"></a>
1089
1090
#### calculate\_alpha\_t
1091
1092
```python
1093
def calculate_alpha_t(t: Tensor) -> Tensor
1094
```
1095
1096
Compute alpha values based on the log SNR schedule.
1097
1098
**Arguments**:
1099
1100
- `t` _Tensor_ - The input tensor representing the time steps.
1101
1102
1103
**Returns**:
1104
1105
- `Tensor` - A tensor representing the alpha values for the given time steps.
1106
1107
1108
**Notes**:
1109
1110
  This method computes alpha values as the square root of the sigmoid of the log SNR.
1111
1112
<a id="mocoschedulesnoisecontinuous_snr_transformsCosineSNRTransform"></a>
1113
1114
## CosineSNRTransform Objects
1115
1116
```python
1117
class CosineSNRTransform(ContinuousSNRTransform)
1118
```
1119
1120
A cosine SNR schedule.
1121
1122
**Arguments**:
1123
1124
- `nu` _Optional[Float]_ - Hyperparameter for the cosine schedule exponent (default is 1.0).
1125
- `s` _Optional[Float]_ - Hyperparameter for the cosine schedule shift (default is 0.008).
1126
1127
<a id="mocoschedulesnoisecontinuous_snr_transformsCosineSNRTransform__init__"></a>
1128
1129
#### \_\_init\_\_
1130
1131
```python
1132
def __init__(nu: Float = 1.0, s: Float = 0.008)
1133
```
1134
1135
Initialize the CosineNoiseSchedule.
1136
1137
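
A sketch of how a concrete transform might be used, relying only on the methods documented above:

```python
import torch

from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform

transform = CosineSNRTransform(nu=1.0, s=0.008)

t = torch.linspace(0.0, 1.0, 100)
log_snr = transform.calculate_log_snr(t)  # pass synchronize=... to flip the direction
alphas, sigmas = transform.log_snr_to_alphas_sigmas(log_snr)  # interpolation coefficients
```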
<a id="mocoschedulesnoisecontinuous_snr_transformsLinearSNRTransform"></a>
1138
1139
## LinearSNRTransform Objects
1140
1141
```python
1142
class LinearSNRTransform(ContinuousSNRTransform)
1143
```
1144
1145
A Linear SNR schedule.
1146
1147
<a id="mocoschedulesnoisecontinuous_snr_transformsLinearSNRTransform__init__"></a>
1148
1149
#### \_\_init\_\_
1150
1151
```python
1152
def __init__(min_value: Float = 1.0e-4)
1153
```
1154
1155
Initialize the Linear SNR Transform.
1156
1157
**Arguments**:
1158
1159
- `min_value` _Float_ - min vaue of SNR defaults to 1.e-4.
1160
1161
<a id="mocoschedulesnoisecontinuous_snr_transformsLinearLogInterpolatedSNRTransform"></a>
1162
1163
## LinearLogInterpolatedSNRTransform Objects
1164
1165
```python
1166
class LinearLogInterpolatedSNRTransform(ContinuousSNRTransform)
1167
```
1168
1169
A Linear Log space interpolated SNR schedule.
1170
1171
<a id="mocoschedulesnoisecontinuous_snr_transformsLinearLogInterpolatedSNRTransform__init__"></a>
1172
1173
#### \_\_init\_\_
1174
1175
```python
1176
def __init__(min_value: Float = -7.0, max_value=13.5)
1177
```
1178
1179
Initialize the Linear log space interpolated SNR Schedule from Chroma.
1180
1181
**Arguments**:
1182
1183
- `min_value` _Float_ - The min log SNR value.
1184
- `max_value` _Float_ - the max log SNR value.
1185
1186
<a id="mocoschedulesnoisediscrete_noise_schedules"></a>
1187
1188
# bionemo.moco.schedules.noise.discrete\_noise\_schedules
1189
1190
<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteNoiseSchedule"></a>
1191
1192
## DiscreteNoiseSchedule Objects
1193
1194
```python
1195
class DiscreteNoiseSchedule(ABC)
1196
```
1197
1198
A base class for discrete noise schedules.
1199
1200
<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteNoiseSchedule__init__"></a>
1201
1202
#### \_\_init\_\_
1203
1204
```python
1205
def __init__(nsteps: int, direction: TimeDirection)
1206
```
1207
1208
Initialize the DiscreteNoiseSchedule.
1209
1210
**Arguments**:
1211
1212
- `nsteps` _int_ - number of discrete steps.
1213
- `direction` _TimeDirection_ - required this defines in which direction the scheduler was built
1214
1215
<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteNoiseSchedulegenerate_schedule"></a>
1216
1217
#### generate\_schedule
1218
1219
```python
1220
def generate_schedule(nsteps: Optional[int] = None,
1221
                      device: Union[str, torch.device] = "cpu",
1222
                      synchronize: Optional[TimeDirection] = None) -> Tensor
1223
```
1224
1225
Generate the noise schedule as a tensor.
1226
1227
**Arguments**:
1228
1229
- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
1230
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1231
- `synchronize` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
1232
  this parameter allows to flip the direction to match the specified one (default is None).
1233
1234
<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteNoiseSchedulecalculate_derivative"></a>
1235
1236
#### calculate\_derivative
1237
1238
```python
1239
def calculate_derivative(
1240
        nsteps: Optional[int] = None,
1241
        device: Union[str, torch.device] = "cpu",
1242
        synchronize: Optional[TimeDirection] = None) -> Tensor
1243
```
1244
1245
Calculate the time derivative of the schedule.
1246
1247
**Arguments**:
1248
1249
- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
1250
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1251
- `synchronize` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
1252
  this parameter allows to flip the direction to match the specified one (default is None).
1253
1254
1255
**Returns**:
1256
1257
- `Tensor` - A tensor representing the time derivative of the schedule.
1258
1259
1260
**Raises**:
1261
1262
- `NotImplementedError` - If the derivative calculation is not implemented for this schedule.
1263
1264
<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteCosineNoiseSchedule"></a>
1265
1266
## DiscreteCosineNoiseSchedule Objects
1267
1268
```python
1269
class DiscreteCosineNoiseSchedule(DiscreteNoiseSchedule)
1270
```
1271
1272
A cosine discrete noise schedule.
1273
1274
<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteCosineNoiseSchedule__init__"></a>
1275
1276
#### \_\_init\_\_
1277
1278
```python
1279
def __init__(nsteps: int, nu: Float = 1.0, s: Float = 0.008)
1280
```
1281
1282
Initialize the CosineNoiseSchedule.
1283
1284
**Arguments**:
1285
1286
- `nsteps` _int_ - Number of discrete steps.
1287
- `nu` _Optional[Float]_ - Hyperparameter for the cosine schedule exponent (default is 1.0).
1288
- `s` _Optional[Float]_ - Hyperparameter for the cosine schedule shift (default is 0.008).
1289
1290
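
A brief sketch; note that `calculate_derivative` may raise `NotImplementedError` for schedules that do not define it, and whether the returned values are alphas or betas depends on the concrete schedule:

```python
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule

schedule = DiscreteCosineNoiseSchedule(nsteps=1000, nu=1.0, s=0.008)

values = schedule.generate_schedule()  # noise schedule tensor of length nsteps
```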
<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteLinearNoiseSchedule"></a>
1291
1292
## DiscreteLinearNoiseSchedule Objects
1293
1294
```python
1295
class DiscreteLinearNoiseSchedule(DiscreteNoiseSchedule)
1296
```
1297
1298
A linear discrete noise schedule.
1299
1300
<a id="mocoschedulesnoisediscrete_noise_schedulesDiscreteLinearNoiseSchedule__init__"></a>
1301
1302
#### \_\_init\_\_
1303
1304
```python
1305
def __init__(nsteps: int, beta_start: Float = 1e-4, beta_end: Float = 0.02)
1306
```
1307
1308
Initialize the CosineNoiseSchedule.
1309
1310
**Arguments**:
1311
1312
- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
1313
- `beta_start` _Optional[int]_ - starting beta value. Defaults to 1e-4.
1314
- `beta_end` _Optional[int]_ - end beta value. Defaults to 0.02.
1315
1316
<a id="mocoschedulesnoise"></a>
1317
1318
# bionemo.moco.schedules.noise
1319
1320
<a id="mocoschedulesnoisecontinuous_noise_transforms"></a>
1321
1322
# bionemo.moco.schedules.noise.continuous\_noise\_transforms
1323
1324
<a id="mocoschedulesnoisecontinuous_noise_transformsContinuousExpNoiseTransform"></a>
1325
1326
## ContinuousExpNoiseTransform Objects
1327
1328
```python
1329
class ContinuousExpNoiseTransform(ABC)
1330
```
1331
1332
A base class for continuous schedules.
1333
1334
alpha = exp(- sigma) where 1 - alpha controls the masking fraction.
1335
1336
<a id="mocoschedulesnoisecontinuous_noise_transformsContinuousExpNoiseTransform__init__"></a>
1337
1338
#### \_\_init\_\_
1339
1340
```python
1341
def __init__(direction: TimeDirection)
1342
```
1343
1344
Initialize the DiscreteNoiseSchedule.
1345
1346
**Arguments**:
1347
1348
  direction : TimeDirection, required this defines in which direction the scheduler was built
1349
1350
<a id="mocoschedulesnoisecontinuous_noise_transformsContinuousExpNoiseTransformcalculate_sigma"></a>
1351
1352
#### calculate\_sigma
1353
1354
```python
1355
def calculate_sigma(t: Tensor,
1356
                    device: Union[str, torch.device] = "cpu",
1357
                    synchronize: Optional[TimeDirection] = None) -> Tensor
1358
```
1359
1360
Calculate the sigma for the given time steps.
1361
1362
**Arguments**:
1363
1364
- `t` _Tensor_ - The input tensor representing the time steps, with values ranging from 0 to 1.
1365
- `device` _Optional[str]_ - The device to place the schedule on. Defaults to "cpu".
1366
- `synchronize` _optional[TimeDirection]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
1367
  this parameter allows to flip the direction to match the specified one. Defaults to None.
1368
1369
1370
**Returns**:
1371
1372
- `Tensor` - A tensor representing the sigma values for the given time steps.
1373
1374
1375
**Raises**:
1376
1377
- `ValueError` - If the input time steps exceed the maximum allowed value of 1.
1378
1379
<a id="mocoschedulesnoisecontinuous_noise_transformsContinuousExpNoiseTransformsigma_to_alpha"></a>
1380
1381
#### sigma\_to\_alpha
1382
1383
```python
1384
def sigma_to_alpha(sigma: Tensor) -> Tensor
1385
```
1386
1387
Converts sigma to alpha values by alpha = exp(- sigma).
1388
1389
**Arguments**:
1390
1391
- `sigma` _Tensor_ - The input sigma tensor.
1392
1393
1394
**Returns**:
1395
1396
- `Tensor` - A tensor containing the alpha values.
1397
1398
<a id="mocoschedulesnoisecontinuous_noise_transformsCosineExpNoiseTransform"></a>
1399
1400
## CosineExpNoiseTransform Objects
1401
1402
```python
1403
class CosineExpNoiseTransform(ContinuousExpNoiseTransform)
1404
```
1405
1406
A cosine Exponential noise schedule.
1407
1408
<a id="mocoschedulesnoisecontinuous_noise_transformsCosineExpNoiseTransform__init__"></a>
1409
1410
#### \_\_init\_\_
1411
1412
```python
1413
def __init__(eps: Float = 1.0e-3)
1414
```
1415
1416
Initialize the CosineNoiseSchedule.
1417
1418
**Arguments**:
1419
1420
- `eps` _Float_ - small number to prevent numerical issues.
1421
1422
<a id="mocoschedulesnoisecontinuous_noise_transformsCosineExpNoiseTransformd_dt_sigma"></a>
1423
1424
#### d\_dt\_sigma
1425
1426
```python
1427
def d_dt_sigma(t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor
1428
```
1429
1430
Compute the derivative of sigma with respect to time.
1431
1432
**Arguments**:
1433
1434
- `t` _Tensor_ - The input tensor representing the time steps.
1435
- `device` _Optional[str]_ - The device to place the schedule on. Defaults to "cpu".
1436
1437
1438
**Returns**:
1439
1440
- `Tensor` - A tensor representing the derivative of sigma with respect to time.
1441
1442
1443
**Notes**:
1444
1445
  The derivative of sigma as a function of time is given by:
1446
1447
  d/dt sigma(t) = d/dt (-log(cos(t * pi / 2) + eps))
1448
1449
  Using the chain rule, we get:
1450
1451
  d/dt sigma(t) = (-1 / (cos(t * pi / 2) + eps)) * (-sin(t * pi / 2) * pi / 2)
1452
1453
  This is the derivative that is computed and returned by this method.
1454
1455
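
A usage sketch pairing `calculate_sigma` with `sigma_to_alpha` (a minimal example, assuming the subclass wires up its own `TimeDirection`):

```python
import torch

from bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform

transform = CosineExpNoiseTransform(eps=1.0e-3)

t = torch.linspace(0.0, 1.0, 100)
sigma = transform.calculate_sigma(t)
alpha = transform.sigma_to_alpha(sigma)  # alpha = exp(-sigma); 1 - alpha is the masking fraction
```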
<a id="mocoschedulesnoisecontinuous_noise_transformsLogLinearExpNoiseTransform"></a>
1456
1457
## LogLinearExpNoiseTransform Objects
1458
1459
```python
1460
class LogLinearExpNoiseTransform(ContinuousExpNoiseTransform)
1461
```
1462
1463
A log linear exponential schedule.
1464
1465
<a id="mocoschedulesnoisecontinuous_noise_transformsLogLinearExpNoiseTransform__init__"></a>
1466
1467
#### \_\_init\_\_
1468
1469
```python
1470
def __init__(eps: Float = 1.0e-3)
1471
```
1472
1473
Initialize the CosineNoiseSchedule.
1474
1475
**Arguments**:
1476
1477
- `eps` _Float_ - small value to prevent numerical issues.
1478
1479
<a id="mocoschedulesnoisecontinuous_noise_transformsLogLinearExpNoiseTransformd_dt_sigma"></a>
1480
1481
#### d\_dt\_sigma
1482
1483
```python
1484
def d_dt_sigma(t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor
1485
```
1486
1487
Compute the derivative of sigma with respect to time.
1488
1489
**Arguments**:
1490
1491
- `t` _Tensor_ - The input tensor representing the time steps.
1492
- `device` _Optional[str]_ - The device to place the schedule on. Defaults to "cpu".
1493
1494
1495
**Returns**:
1496
1497
- `Tensor` - A tensor representing the derivative of sigma with respect to time.
1498
1499
<a id="mocoschedules"></a>
1500
1501
# bionemo.moco.schedules
1502
1503
<a id="mocoschedulesutils"></a>
1504
1505
# bionemo.moco.schedules.utils
1506
1507
<a id="mocoschedulesutilsTimeDirection"></a>
1508
1509
## TimeDirection Objects
1510
1511
```python
1512
class TimeDirection(Enum)
1513
```
1514
1515
Enum for the direction of the noise schedule.
1516
1517
<a id="mocoschedulesutilsTimeDirectionUNIFIED"></a>
1518
1519
#### UNIFIED
1520
1521
Noise(0) --> Data(1)
1522
1523
<a id="mocoschedulesutilsTimeDirectionDIFFUSION"></a>
1524
1525
#### DIFFUSION
1526
1527
Noise(1) --> Data(0)
1528
1529
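
In other words, UNIFIED runs time from noise at t=0 to data at t=1 (the usual flow matching convention), while DIFFUSION runs from noise at t=1 to data at t=0. A sketch of using the enum to request a schedule in one convention regardless of how it was built, via the `synchronize` parameter documented earlier:

```python
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
from bionemo.moco.schedules.utils import TimeDirection

schedule = DiscreteCosineNoiseSchedule(nsteps=1000)

# Flip the schedule into the UNIFIED (noise -> data) ordering if it was
# built in the opposite direction.
values = schedule.generate_schedule(synchronize=TimeDirection.UNIFIED)
```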
<a id="mocoschedulesinference_time_schedules"></a>
1530
1531
# bionemo.moco.schedules.inference\_time\_schedules
1532
1533
<a id="mocoschedulesinference_time_schedulesInferenceSchedule"></a>
1534
1535
## InferenceSchedule Objects
1536
1537
```python
1538
class InferenceSchedule(ABC)
1539
```
1540
1541
A base class for inference time schedules.
1542
1543
<a id="mocoschedulesinference_time_schedulesInferenceSchedule__init__"></a>
1544
1545
#### \_\_init\_\_
1546
1547
```python
1548
def __init__(nsteps: int,
1549
             min_t: Float = 0,
1550
             padding: Float = 0,
1551
             dilation: Float = 0,
1552
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
1553
             device: Union[str, torch.device] = "cpu")
1554
```
1555
1556
Initialize the InferenceSchedule.
1557
1558
**Arguments**:
1559
1560
- `nsteps` _int_ - Number of time steps.
1561
- `min_t` _Float_ - minimum time value defaults to 0.
1562
- `padding` _Float_ - padding time value defaults to 0.
1563
- `dilation` _Float_ - dilation time value defaults to 0 ie the number of replicates.
1564
- `direction` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
1565
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1566
1567
<a id="mocoschedulesinference_time_schedulesInferenceSchedulegenerate_schedule"></a>
1568
1569
#### generate\_schedule
1570
1571
```python
1572
@abstractmethod
1573
def generate_schedule(
1574
        nsteps: Optional[int] = None,
1575
        device: Optional[Union[str, torch.device]] = None) -> Tensor
1576
```
1577
1578
Generate the time schedule as a tensor.
1579
1580
**Arguments**:
1581
1582
- `nsteps` _Optioanl[int]_ - Number of time steps. If None, uses the value from initialization.
1583
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1584
1585
<a id="mocoschedulesinference_time_schedulesInferenceSchedulepad_time"></a>
1586
1587
#### pad\_time
1588
1589
```python
1590
def pad_time(n_samples: int,
1591
             scalar_time: Float,
1592
             device: Optional[Union[str, torch.device]] = None) -> Tensor
1593
```
1594
1595
Creates a tensor of shape (n_samples,) filled with a scalar time value.
1596
1597
**Arguments**:
1598
1599
- `n_samples` _int_ - The desired dimension of the output tensor.
1600
- `scalar_time` _Float_ - The scalar time value to fill the tensor with.
1601
  device (Optional[Union[str, torch.device]], optional):
1602
  The device to place the tensor on. Defaults to None, which uses the default device.
1603
1604
1605
**Returns**:
1606
1607
- `Tensor` - A tensor of shape (n_samples,) filled with the scalar time value.
1608
1609
<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedule"></a>
1610
1611
## ContinuousInferenceSchedule Objects
1612
1613
```python
1614
class ContinuousInferenceSchedule(InferenceSchedule)
1615
```
1616
1617
A base class for continuous time inference schedules.
1618
1619
<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedule__init__"></a>
1620
1621
#### \_\_init\_\_
1622
1623
```python
1624
def __init__(nsteps: int,
1625
             inclusive_end: bool = False,
1626
             min_t: Float = 0,
1627
             padding: Float = 0,
1628
             dilation: Float = 0,
1629
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
1630
             device: Union[str, torch.device] = "cpu")
1631
```
1632
1633
Initialize the ContinuousInferenceSchedule.
1634
1635
**Arguments**:
1636
1637
- `nsteps` _int_ - Number of time steps.
1638
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule otherwise ends at 1.0-1/nsteps (default is False).
1639
- `min_t` _Float_ - minimum time value defaults to 0.
1640
- `padding` _Float_ - padding time value defaults to 0.
1641
- `dilation` _Float_ - dilation time value defaults to 0 ie the number of replicates.
1642
- `direction` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
1643
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1644
1645
<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedulediscretize"></a>
1646
1647
#### discretize
1648
1649
```python
1650
def discretize(nsteps: Optional[int] = None,
1651
               schedule: Optional[Tensor] = None,
1652
               device: Optional[Union[str, torch.device]] = None) -> Tensor
1653
```
1654
1655
Discretize the time schedule into a list of time deltas.
1656
1657
**Arguments**:
1658
1659
- `nsteps` _Optioanl[int]_ - Number of time steps. If None, uses the value from initialization.
1660
- `schedule` _Optional[Tensor]_ - Time scheudle if None will generate it with generate_schedule.
1661
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1662
1663
1664
**Returns**:
1665
1666
- `Tensor` - A tensor of time deltas.
1667
1668
<a id="mocoschedulesinference_time_schedulesDiscreteInferenceSchedule"></a>
1669
1670
## DiscreteInferenceSchedule Objects
1671
1672
```python
1673
class DiscreteInferenceSchedule(InferenceSchedule)
1674
```
1675
1676
A base class for discrete time inference schedules.
1677
1678
<a id="mocoschedulesinference_time_schedulesDiscreteInferenceSchedulediscretize"></a>
1679
1680
#### discretize
1681
1682
```python
1683
def discretize(nsteps: Optional[int] = None,
1684
               device: Optional[Union[str, torch.device]] = None) -> Tensor
1685
```
1686
1687
Discretize the time schedule into a list of time deltas.
1688
1689
**Arguments**:
1690
1691
- `nsteps` _Optioanl[int]_ - Number of time steps. If None, uses the value from initialization.
1692
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1693
1694
1695
**Returns**:
1696
1697
- `Tensor` - A tensor of time deltas.
1698
1699
<a id="mocoschedulesinference_time_schedulesDiscreteLinearInferenceSchedule"></a>
1700
1701
## DiscreteLinearInferenceSchedule Objects
1702
1703
```python
1704
class DiscreteLinearInferenceSchedule(DiscreteInferenceSchedule)
1705
```
1706
1707
A linear time schedule for discrete time inference.
1708
1709
<a id="mocoschedulesinference_time_schedulesDiscreteLinearInferenceSchedule__init__"></a>
1710
1711
#### \_\_init\_\_
1712
1713
```python
1714
def __init__(nsteps: int,
1715
             min_t: Float = 0,
1716
             padding: Float = 0,
1717
             dilation: Float = 0,
1718
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
1719
             device: Union[str, torch.device] = "cpu")
1720
```
1721
1722
Initialize the DiscreteLinearInferenceSchedule.
1723
1724
**Arguments**:
1725
1726
- `nsteps` _int_ - Number of time steps.
1727
- `min_t` _Float_ - minimum time value defaults to 0.
1728
- `padding` _Float_ - padding time value defaults to 0.
1729
- `dilation` _Float_ - dilation time value defaults to 0 ie the number of replicates.
1730
- `direction` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
1731
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1732
1733
<a id="mocoschedulesinference_time_schedulesDiscreteLinearInferenceSchedulegenerate_schedule"></a>
1734
1735
#### generate\_schedule
1736
1737
```python
1738
def generate_schedule(
1739
        nsteps: Optional[int] = None,
1740
        device: Optional[Union[str, torch.device]] = None) -> Tensor
1741
```
1742
1743
Generate the linear time schedule as a tensor.
1744
1745
**Arguments**:
1746
1747
- `nsteps` _Optional[int]_ - Number of time steps. If None uses the value from initialization.
1748
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1749
1750
1751
**Returns**:
1752
1753
- `Tensor` - A tensor of time steps.
1754
- `Tensor` - A tensor of time steps.
1755
1756
<a id="mocoschedulesinference_time_schedulesLinearInferenceSchedule"></a>
1757
1758
## LinearInferenceSchedule Objects
1759
1760
```python
1761
class LinearInferenceSchedule(ContinuousInferenceSchedule)
1762
```
1763
1764
A linear time schedule for continuous time inference.
1765
1766
<a id="mocoschedulesinference_time_schedulesLinearInferenceSchedule__init__"></a>
1767
1768
#### \_\_init\_\_
1769
1770
```python
1771
def __init__(nsteps: int,
1772
             inclusive_end: bool = False,
1773
             min_t: Float = 0,
1774
             padding: Float = 0,
1775
             dilation: Float = 0,
1776
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
1777
             device: Union[str, torch.device] = "cpu")
1778
```
1779
1780
Initialize the LinearInferenceSchedule.
1781
1782
**Arguments**:
1783
1784
- `nsteps` _int_ - Number of time steps.
1785
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule otherwise ends at 1.0-1/nsteps (default is False).
1786
- `min_t` _Float_ - minimum time value defaults to 0.
1787
- `padding` _Float_ - padding time value defaults to 0.
1788
- `dilation` _Float_ - dilation time value defaults to 0 ie the number of replicates.
1789
- `direction` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
1790
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1791
1792
<a id="mocoschedulesinference_time_schedulesLinearInferenceSchedulegenerate_schedule"></a>
1793
1794
#### generate\_schedule
1795
1796
```python
1797
def generate_schedule(
1798
        nsteps: Optional[int] = None,
1799
        device: Optional[Union[str, torch.device]] = None) -> Tensor
1800
```
1801
1802
Generate the linear time schedule as a tensor.
1803
1804
**Arguments**:
1805
1806
- `nsteps` _Optional[int]_ - Number of time steps. If None uses the value from initialization.
1807
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1808
1809
1810
**Returns**:
1811
1812
- `Tensor` - A tensor of time steps.
1813
1814
<a id="mocoschedulesinference_time_schedulesPowerInferenceSchedule"></a>
1815
1816
## PowerInferenceSchedule Objects
1817
1818
```python
1819
class PowerInferenceSchedule(ContinuousInferenceSchedule)
1820
```
1821
1822
A power time schedule for inference, where time steps are generated by raising a uniform schedule to a specified power.
1823
1824
<a id="mocoschedulesinference_time_schedulesPowerInferenceSchedule__init__"></a>
1825
1826
#### \_\_init\_\_
1827
1828
```python
1829
def __init__(nsteps: int,
1830
             inclusive_end: bool = False,
1831
             min_t: Float = 0,
1832
             padding: Float = 0,
1833
             dilation: Float = 0,
1834
             exponent: Float = 1.0,
1835
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
1836
             device: Union[str, torch.device] = "cpu")
1837
```
1838
1839
Initialize the PowerInferenceSchedule.
1840
1841
**Arguments**:
1842
1843
- `nsteps` _int_ - Number of time steps.
1844
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule otherwise ends at <1.0 (default is False).
1845
- `min_t` _Float_ - minimum time value defaults to 0.
1846
- `padding` _Float_ - padding time value defaults to 0.
1847
- `dilation` _Float_ - dilation time value defaults to 0 ie the number of replicates.
1848
- `exponent` _Float_ - Power parameter defaults to 1.0.
1849
- `direction` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
1850
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1851
1852
<a id="mocoschedulesinference_time_schedulesPowerInferenceSchedulegenerate_schedule"></a>
1853
1854
#### generate\_schedule
1855
1856
```python
1857
def generate_schedule(
1858
        nsteps: Optional[int] = None,
1859
        device: Optional[Union[str, torch.device]] = None) -> Tensor
1860
```
1861
1862
Generate the power time schedule as a tensor.
1863
1864
**Arguments**:
1865
1866
- `nsteps` _Optional[int]_ - Number of time steps. If None uses the value from initialization.
1867
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1868
1869
1870
1871
**Returns**:
1872
1873
- `Tensor` - A tensor of time steps.
1874
- `Tensor` - A tensor of time steps.
1875
1876
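For intuition, a small sketch of how `exponent` reshapes the schedule; since the steps come from raising a uniform grid to the given power, exponents above 1 concentrate steps near t = 0 and exponents below 1 concentrate them near t = 1 (step counts here are illustrative):

```python
from bionemo.moco.schedules.inference_time_schedules import PowerInferenceSchedule

quadratic = PowerInferenceSchedule(nsteps=5, exponent=2.0).generate_schedule()  # dense near 0
sqrt_like = PowerInferenceSchedule(nsteps=5, exponent=0.5).generate_schedule()  # dense near 1
```
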
<a id="mocoschedulesinference_time_schedulesLogInferenceSchedule"></a>
1877
1878
## LogInferenceSchedule Objects
1879
1880
```python
1881
class LogInferenceSchedule(ContinuousInferenceSchedule)
1882
```
1883
1884
A log time schedule for inference, where time steps are generated by taking the logarithm of a uniform schedule.
1885
1886
<a id="mocoschedulesinference_time_schedulesLogInferenceSchedule__init__"></a>
1887
1888
#### \_\_init\_\_
1889
1890
```python
1891
def __init__(nsteps: int,
1892
             inclusive_end: bool = False,
1893
             min_t: Float = 0,
1894
             padding: Float = 0,
1895
             dilation: Float = 0,
1896
             exponent: Float = -2.0,
1897
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
1898
             device: Union[str, torch.device] = "cpu")
1899
```
1900
1901
Initialize the LogInferenceSchedule.
1902
1903
Returns a log space time schedule.
1904
1905
Which for 100 steps with default parameters is:
1906
tensor([0.0000, 0.0455, 0.0889, 0.1303, 0.1699, 0.2077, 0.2439, 0.2783, 0.3113,
1907
0.3427, 0.3728, 0.4015, 0.4288, 0.4550, 0.4800, 0.5039, 0.5266, 0.5484,
1908
0.5692, 0.5890, 0.6080, 0.6261, 0.6434, 0.6599, 0.6756, 0.6907, 0.7051,
1909
0.7188, 0.7319, 0.7444, 0.7564, 0.7678, 0.7787, 0.7891, 0.7991, 0.8086,
1910
0.8176, 0.8263, 0.8346, 0.8425, 0.8500, 0.8572, 0.8641, 0.8707, 0.8769,
1911
0.8829, 0.8887, 0.8941, 0.8993, 0.9043, 0.9091, 0.9136, 0.9180, 0.9221,
1912
0.9261, 0.9299, 0.9335, 0.9369, 0.9402, 0.9434, 0.9464, 0.9492, 0.9520,
1913
0.9546, 0.9571, 0.9595, 0.9618, 0.9639, 0.9660, 0.9680, 0.9699, 0.9717,
1914
0.9734, 0.9751, 0.9767, 0.9782, 0.9796, 0.9810, 0.9823, 0.9835, 0.9847,
1915
0.9859, 0.9870, 0.9880, 0.9890, 0.9899, 0.9909, 0.9917, 0.9925, 0.9933,
1916
0.9941, 0.9948, 0.9955, 0.9962, 0.9968, 0.9974, 0.9980, 0.9985, 0.9990,
1917
0.9995])
1918
1919
**Arguments**:
1920
1921
- `nsteps` _int_ - Number of time steps.
1922
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule otherwise ends at <1.0 (default is False).
1923
- `min_t` _Float_ - minimum time value defaults to 0.
1924
- `padding` _Float_ - padding time value defaults to 0.
1925
- `dilation` _Float_ - dilation time value defaults to 0 ie the number of replicates.
1926
- `exponent` _Float_ - log space exponent parameter defaults to -2.0. The lower number the more aggressive the acceleration of 0 to 0.9 will be thus having more steps from 0.9 to 1.0.
1927
- `direction` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
1928
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1929
1930
<a id="mocoschedulesinference_time_schedulesLogInferenceSchedulegenerate_schedule"></a>
1931
1932
#### generate\_schedule
1933
1934
```python
1935
def generate_schedule(
1936
        nsteps: Optional[int] = None,
1937
        device: Optional[Union[str, torch.device]] = None) -> Tensor
1938
```
1939
1940
Generate the log time schedule as a tensor.
1941
1942
**Arguments**:
1943
1944
- `nsteps` _Optional[int]_ - Number of time steps. If None uses the value from initialization.
1945
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
1946
1947
<a id="mocointerpolantscontinuous_timediscrete"></a>
1948
1949
# bionemo.moco.interpolants.continuous\_time.discrete
1950
1951
<a id="mocointerpolantscontinuous_timediscretemdlm"></a>
1952
1953
# bionemo.moco.interpolants.continuous\_time.discrete.mdlm
1954
1955
<a id="mocointerpolantscontinuous_timediscretemdlmMDLM"></a>
1956
1957
## MDLM Objects
1958
1959
```python
1960
class MDLM(Interpolant)
1961
```
1962
1963
A Masked discrete Diffusion Language Model (MDLM) interpolant.
1964
1965
-------
1966
1967
**Examples**:
1968
1969
```python
1970
>>> import torch
1971
>>> from bionemo.bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior
1972
>>> from bionemo.bionemo.moco.distributions.time.uniform import UniformTimeDistribution
1973
>>> from bionemo.bionemo.moco.interpolants.continuous_time.discrete.mdlm import MDLM
1974
>>> from bionemo.bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform
1975
>>> from bionemo.bionemo.moco.schedules.inference_time_schedules import LinearTimeSchedule
1976
1977
1978
mdlm = MDLM(
1979
    time_distribution = UniformTimeDistribution(discrete_time = False,...),
1980
    prior_distribution = DiscreteMaskedPrior(...),
1981
    noise_schedule = CosineExpNoiseTransform(...),
1982
    )
1983
model = Model(...)
1984
1985
# Training
1986
for epoch in range(1000):
1987
    data = data_loader.get(...)
1988
    time = mdlm.sample_time(batch_size)
1989
    xt = mdlm.interpolate(data, time)
1990
1991
    logits = model(xt, time)
1992
    loss = mdlm.loss(logits, data, xt, time)
1993
    loss.backward()
1994
1995
# Generation
1996
x_pred = mdlm.sample_prior(data.shape)
1997
schedule = LinearTimeSchedule(...)
1998
inference_time = schedule.generate_schedule()
1999
dts = schedue.discreteize()
2000
for t, dt in zip(inference_time, dts):
2001
    time = torch.full((batch_size,), t)
2002
    logits = model(x_pred, time)
2003
    x_pred = mdlm.step(logits, time, x_pred, dt)
2004
return x_pred
2005
2006
```
2007
2008
<a id="mocointerpolantscontinuous_timediscretemdlmMDLM__init__"></a>
2009
2010
#### \_\_init\_\_
2011
2012
```python
2013
def __init__(time_distribution: TimeDistribution,
2014
             prior_distribution: DiscreteMaskedPrior,
2015
             noise_schedule: ContinuousExpNoiseTransform,
2016
             device: str = "cpu",
2017
             rng_generator: Optional[torch.Generator] = None)
2018
```
2019
2020
Initialize the Masked Discrete Language Model (MDLM) interpolant.
2021
2022
**Arguments**:
2023
2024
- `time_distribution` _TimeDistribution_ - The distribution governing the time variable in the diffusion process.
2025
- `prior_distribution` _DiscreteMaskedPrior_ - The prior distribution over the discrete token space, including masked tokens.
2026
- `noise_schedule` _ContinuousExpNoiseTransform_ - The noise schedule defining the noise intensity as a function of time.
2027
- `device` _str, optional_ - The device to use for computations. Defaults to "cpu".
2028
- `rng_generator` _Optional[torch.Generator], optional_ - The random number generator for reproducibility. Defaults to None.
2029
2030
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMinterpolate"></a>
2031
2032
#### interpolate
2033
2034
```python
2035
def interpolate(data: Tensor, t: Tensor)
2036
```
2037
2038
Get x(t) with given time t from noise and data.
2039
2040
**Arguments**:
2041
2042
- `data` _Tensor_ - target discrete ids
2043
- `t` _Tensor_ - time
2044
2045
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMforward_process"></a>
2046
2047
#### forward\_process
2048
2049
```python
2050
def forward_process(data: Tensor, t: Tensor) -> Tensor
2051
```
2052
2053
Apply the forward process to the data at time t.
2054
2055
**Arguments**:
2056
2057
- `data` _Tensor_ - target discrete ids
2058
- `t` _Tensor_ - time
2059
2060
2061
**Returns**:
2062
2063
- `Tensor` - x(t) after applying the forward process
2064
2065
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMloss"></a>
2066
2067
#### loss
2068
2069
```python
2070
def loss(logits: Tensor,
2071
         target: Tensor,
2072
         xt: Tensor,
2073
         time: Tensor,
2074
         mask: Optional[Tensor] = None,
2075
         use_weight=True)
2076
```
2077
2078
Calculate the cross-entropy loss between the model prediction and the target output.
2079
2080
The loss is calculated between the batch x node x class logits and the target batch x node,
2081
considering the current state of the discrete sequence `xt` at time `time`.
2082
2083
If `use_weight` is True, the loss is weighted by the reduced form of the MDLM time weight for continuous NELBO,
2084
as specified in equation 11 of https://arxiv.org/pdf/2406.07524. This weight is proportional to the derivative
2085
of the noise schedule with respect to time, and is used to emphasize the importance of accurate predictions at
2086
certain times in the diffusion process.
2087
2088
**Arguments**:
2089
2090
- `logits` _Tensor_ - The predicted output from the model, with shape batch x node x class.
2091
- `target` _Tensor_ - The target output for the model prediction, with shape batch x node.
2092
- `xt` _Tensor_ - The current state of the discrete sequence, with shape batch x node.
2093
- `time` _Tensor_ - The time at which the loss is calculated.
2094
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
2095
- `use_weight` _bool, optional_ - Whether to use the MDLM time weight for the loss. Defaults to True.
2096
2097
2098
**Returns**:
2099
2100
- `Tensor` - The calculated loss batch tensor.
2101
2102
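A hedged sketch of threading a padding mask through the loss; the `pad_token_id` convention and the reduction are assumptions, not part of the API:

```python
# Exclude padded positions from the loss; `data`, `logits`, `xt`, `time` as in the example above.
pad_mask = (data != pad_token_id)  # hypothetical padding convention, shape batch x node
loss = mdlm.loss(logits, data, xt, time, mask=pad_mask, use_weight=True)
loss.mean().backward()  # reduce the per-batch-element loss before backprop
```
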
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep"></a>
2103
2104
#### step
2105
2106
```python
2107
def step(logits: Tensor,
2108
         t: Tensor,
2109
         xt: Tensor,
2110
         dt: Tensor,
2111
         temperature: float = 1.0) -> Tensor
2112
```
2113
2114
Perform a single step of MDLM DDPM step.
2115
2116
**Arguments**:
2117
2118
- `logits` _Tensor_ - The input logits.
2119
- `t` _Tensor_ - The current time step.
2120
- `xt` _Tensor_ - The current state.
2121
- `dt` _Tensor_ - The time step increment.
2122
- `temperature` _float_ - Softmax temperature defaults to 1.0.
2123
2124
2125
**Returns**:
2126
2127
- `Tensor` - The updated state.
2128
2129
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMget_num_steps_confidence"></a>
2130
2131
#### get\_num\_steps\_confidence
2132
2133
```python
2134
def get_num_steps_confidence(xt: Tensor, num_tokens_unmask: int = 1)
2135
```
2136
2137
Calculate the maximum number of steps with confidence.
2138
2139
This method computes the maximum count of occurrences where the input tensor `xt` matches the `mask_index`
2140
along the last dimension (-1). The result is returned as a single float value.
2141
2142
**Arguments**:
2143
2144
- `xt` _Tensor_ - Input tensor to evaluate against the mask index.
2145
- `num_tokens_unmask` _int_ - number of tokens to unamsk at each step.
2146
2147
2148
**Returns**:
2149
2150
- `float` - The maximum number of steps with confidence (i.e., matching the mask index).
2151
2152
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_confidence"></a>
2153
2154
#### step\_confidence
2155
2156
```python
2157
def step_confidence(logits: Tensor,
2158
                    xt: Tensor,
2159
                    curr_step: int,
2160
                    num_steps: int,
2161
                    logit_temperature: float = 1.0,
2162
                    randomness: float = 1.0,
2163
                    confidence_temperature: float = 1.0,
2164
                    num_tokens_unmask: int = 1) -> Tensor
2165
```
2166
2167
Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.
2168
2169
Method taken from GenMol Lee et al. https://arxiv.org/abs/2501.06158
2170
2171
**Arguments**:
2172
2173
- `logits` - Predicted logits
2174
- `xt` - Input sequence
2175
- `curr_step` - Current step
2176
- `num_steps` - Total number of steps
2177
- `logit_temperature` - Temperature for softmax over logits
2178
- `randomness` - Scale for Gumbel noise
2179
- `confidence_temperature` - Temperature for Gumbel confidence
2180
- `num_tokens_unmask` - number of tokens to unmask each step
2181
2182
2183
**Returns**:
2184
2185
  Updated input sequence xt unmasking num_tokens_unmask token each step.
2186
2187
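For orientation, a minimal confidence-sampling loop under the API above; how `model` is conditioned (the `time` argument) is an assumption carried over from the class example:

```python
xt = mdlm.sample_prior(data.shape)
num_steps = int(mdlm.get_num_steps_confidence(xt, num_tokens_unmask=1))
for curr_step in range(num_steps):
    logits = model(xt, time)  # conditioning of `model` here is illustrative
    xt = mdlm.step_confidence(logits, xt, curr_step, num_steps, num_tokens_unmask=1)
```
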
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_argmax"></a>
2188
2189
#### step\_argmax
2190
2191
```python
2192
def step_argmax(model_out: Tensor)
2193
```
2194
2195
Returns the index of the maximum value in the last dimension of the model output.
2196
2197
**Arguments**:
2198
2199
- `model_out` _Tensor_ - The output of the model.
2200
2201
2202
**Returns**:
2203
2204
- `Tensor` - The index of the maximum value in the last dimension of the model output.
2205
2206
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMcalculate_score"></a>
2207
2208
#### calculate\_score
2209
2210
```python
2211
def calculate_score(logits, x, t)
2212
```
2213
2214
Returns score of the given sample x at time t with the corresponding model output logits.
2215
2216
**Arguments**:
2217
2218
- `logits` _Tensor_ - The output of the model.
2219
- `x` _Tensor_ - The current data point.
2220
- `t` _Tensor_ - The current time.
2221
2222
2223
**Returns**:
2224
2225
- `Tensor` - The score defined in Appendix C.3 Equation 76 of MDLM.
2226
2227
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_self_path_planning"></a>
2228
2229
#### step\_self\_path\_planning
2230
2231
```python
2232
def step_self_path_planning(logits: Tensor,
2233
                            xt: Tensor,
2234
                            t: Tensor,
2235
                            curr_step: int,
2236
                            num_steps: int,
2237
                            logit_temperature: float = 1.0,
2238
                            randomness: float = 1.0,
2239
                            confidence_temperature: float = 1.0,
2240
                            score_type: Literal["confidence",
2241
                                                "random"] = "confidence",
2242
                            fix_mask: Optional[Tensor] = None) -> Tensor
2243
```
2244
2245
Self Path Planning (P2) Sampling from Peng et al. https://arxiv.org/html/2502.03540v1.
2246
2247
**Arguments**:
2248
2249
- `logits` _Tensor_ - Predicted logits for sampling.
2250
- `xt` _Tensor_ - Input sequence to be updated.
2251
- `t` _Tensor_ - Time tensor (e.g., time steps or temporal info).
2252
- `curr_step` _int_ - Current iteration in the planning process.
2253
- `num_steps` _int_ - Total number of planning steps.
2254
- `logit_temperature` _float_ - Temperature for logits (default: 1.0).
2255
- `randomness` _float_ - Introduced randomness level (default: 1.0).
2256
- `confidence_temperature` _float_ - Temperature for confidence scoring (default: 1.0).
2257
- `score_type` _Literal["confidence", "random"]_ - Sampling score type (default: "confidence").
2258
- `fix_mask` _Optional[Tensor]_ - inital mask where True when not a mask tokens (default: None).
2259
2260
2261
**Returns**:
2262
2263
- `Tensor` - Updated input sequence xt after iterative unmasking.
2264
2265
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMtopk_lowest_masking"></a>
2266
2267
#### topk\_lowest\_masking
2268
2269
```python
2270
def topk_lowest_masking(scores: Tensor, cutoff_len: Tensor)
2271
```
2272
2273
Generates a mask for the lowest scoring elements up to a specified cutoff length.
2274
2275
**Arguments**:
2276
2277
- `scores` _Tensor_ - Input scores tensor with shape (... , num_elements)
2278
- `cutoff_len` _Tensor_ - Number of lowest-scoring elements to mask (per batch element)
2279
2280
2281
**Returns**:
2282
2283
- `Tensor` - Boolean mask tensor with same shape as `scores`, where `True` indicates
2284
  the corresponding element is among the `cutoff_len` lowest scores.
2285
2286
2287
**Example**:
2288
2289
  >>> scores = torch.tensor([[0.9, 0.8, 0.1, 0.05], [0.7, 0.4, 0.3, 0.2]])
2290
  >>> cutoff_len = 2
2291
  >>> mask = topk_lowest_masking(scores, cutoff_len)
2292
  >>> print(mask)
2293
  tensor([[False, False, True, True],
2294
  [False, True, True, False]])
2295
2296
<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstochastic_sample_from_categorical"></a>
2297
2298
#### stochastic\_sample\_from\_categorical
2299
2300
```python
2301
def stochastic_sample_from_categorical(logits: Tensor,
2302
                                       temperature: float = 1.0,
2303
                                       noise_scale: float = 1.0)
2304
```
2305
2306
Stochastically samples from a categorical distribution defined by input logits, with optional temperature and noise scaling for diverse sampling.
2307
2308
**Arguments**:
2309
2310
- `logits` _Tensor_ - Input logits tensor with shape (... , num_categories)
2311
- `temperature` _float, optional_ - Softmax temperature. Higher values produce more uniform samples. Defaults to 1.0.
2312
- `noise_scale` _float, optional_ - Scale for Gumbel noise. Higher values produce more diverse samples. Defaults to 1.0.
2313
2314
2315
**Returns**:
2316
2317
  tuple:
2318
  - **tokens** (LongTensor): Sampling result (category indices) with shape (... , )
2319
  - **scores** (Tensor): Corresponding log-softmax scores for the sampled tokens, with shape (... , )
2320
2321
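A minimal sketch of the Gumbel trick the function describes; this is an illustration of the technique, not the library's exact implementation (whether the temperature also enters the returned scores is an assumption left out here):

```python
import torch
import torch.nn.functional as F

def gumbel_sample(logits: torch.Tensor, temperature: float = 1.0, noise_scale: float = 1.0):
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1), clamped for stability.
    u = torch.rand_like(logits).clamp(1e-20, 1.0 - 1e-7)
    gumbel = -torch.log(-torch.log(u))
    # Perturb temperature-scaled logits with scaled Gumbel noise, then take the argmax.
    tokens = (logits / temperature + noise_scale * gumbel).argmax(dim=-1)
    # Log-softmax score of each sampled token.
    scores = F.log_softmax(logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    return tokens, scores
```
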
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matching"></a>
2322
2323
# bionemo.moco.interpolants.continuous\_time.discrete.discrete\_flow\_matching
2324
2325
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcher"></a>
2326
2327
## DiscreteFlowMatcher Objects
2328
2329
```python
2330
class DiscreteFlowMatcher(Interpolant)
2331
```
2332
2333
A Discrete Flow Model (DFM) interpolant.
2334
2335
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcher__init__"></a>
2336
2337
#### \_\_init\_\_
2338
2339
```python
2340
def __init__(time_distribution: TimeDistribution,
2341
             prior_distribution: DiscretePriorDistribution,
2342
             device: str = "cpu",
2343
             eps: Float = 1e-5,
2344
             rng_generator: Optional[torch.Generator] = None)
2345
```
2346
2347
Initialize the DFM interpolant.
2348
2349
**Arguments**:
2350
2351
- `time_distribution` _TimeDistribution_ - The time distribution for the diffusion process.
2352
- `prior_distribution` _DiscretePriorDistribution_ - The prior distribution for the discrete masked tokens.
2353
- `device` _str, optional_ - The device to use for computations. Defaults to "cpu".
2354
- `eps` - small Float to prevent dividing by zero.
2355
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
2356
2357
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherinterpolate"></a>
2358
2359
#### interpolate
2360
2361
```python
2362
def interpolate(data: Tensor, t: Tensor, noise: Tensor)
2363
```
2364
2365
Get x(t) with given time t from noise and data.
2366
2367
**Arguments**:
2368
2369
- `data` _Tensor_ - target discrete ids
2370
- `t` _Tensor_ - time
2371
- `noise` - tensor noise ids
2372
2373
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherloss"></a>
2374
2375
#### loss
2376
2377
```python
2378
def loss(logits: Tensor,
2379
         target: Tensor,
2380
         time: Optional[Tensor] = None,
2381
         mask: Optional[Tensor] = None,
2382
         use_weight: Bool = False)
2383
```
2384
2385
Calculate the cross-entropy loss between the model prediction and the target output.
2386
2387
The loss is calculated between the batch x node x class logits and the target batch x node.
2388
If using a masked prior please pass in the correct mask to calculate loss values on only masked states.
2389
i.e. mask = data_mask * is_masked_state which is calculated with self.prior_dist.is_masked(xt))
2390
2391
If `use_weight` is True, the loss is weighted by 1/(1-t) defined in equation 24 in Appndix C. of https://arxiv.org/pdf/2402.04997
2392
2393
**Arguments**:
2394
2395
- `logits` _Tensor_ - The predicted output from the model, with shape batch x node x class.
2396
- `target` _Tensor_ - The target output for the model prediction, with shape batch x node.
2397
- `time` _Tensor_ - The time at which the loss is calculated.
2398
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
2399
- `use_weight` _bool, optional_ - Whether to use the DFM time weight for the loss. Defaults to True.
2400
2401
2402
**Returns**:
2403
2404
- `Tensor` - The calculated loss batch tensor.
2405
2406
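Following the docstring's masking recipe, a hedged sketch; here `dfm` is a DiscreteFlowMatcher instance and `data_mask` a padding mask, both assumptions:

```python
# Restrict the loss to positions the masked prior actually corrupted, as described above.
is_masked_state = dfm.prior_dist.is_masked(xt)  # attribute and method per the docstring
mask = data_mask * is_masked_state
loss = dfm.loss(logits, data, time, mask=mask, use_weight=False).mean()
```
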
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherstep"></a>
2407
2408
#### step
2409
2410
```python
2411
def step(logits: Tensor,
2412
         t: Tensor,
2413
         xt: Tensor,
2414
         dt: Tensor | float,
2415
         temperature: Float = 1.0,
2416
         stochasticity: Float = 1.0) -> Tensor
2417
```
2418
2419
Perform a single step of DFM euler updates.
2420
2421
**Arguments**:
2422
2423
- `logits` _Tensor_ - The input logits.
2424
- `t` _Tensor_ - The current time step.
2425
- `xt` _Tensor_ - The current state.
2426
- `dt` _Tensor | float_ - The time step increment.
2427
- `temperature` _Float, optional_ - The temperature for the softmax calculation. Defaults to 1.0.
2428
- `stochasticity` _Float, optional_ - The stochasticity value for the step calculation. Defaults to 1.0.
2429
2430
2431
**Returns**:
2432
2433
- `Tensor` - The updated state.
2434
2435
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherstep_purity"></a>
2436
2437
#### step\_purity
2438
2439
```python
2440
def step_purity(logits: Tensor,
2441
                t: Tensor,
2442
                xt: Tensor,
2443
                dt: Tensor | float,
2444
                temperature: Float = 1.0,
2445
                stochasticity: Float = 1.0) -> Tensor
2446
```
2447
2448
Perform a single step of purity sampling.
2449
2450
https://github.com/jasonkyuyim/multiflow/blob/6278899970523bad29953047e7a42b32a41dc813/multiflow/data/interpolant.py#L346
2451
Here's a high-level overview of what the function does:
2452
TODO: check if the -1e9 and 1e-9 are small enough or using torch.inf would be better
2453
2454
1. Preprocessing:
2455
Checks if dt is a float and converts it to a tensor if necessary.
2456
Pads t and dt to match the shape of xt.
2457
Checks if the mask_index is valid (i.e., within the range of possible discrete values).
2458
2. Masking:
2459
Sets the logits corresponding to the mask_index to a low value (-1e9) to effectively mask out those values.
2460
Computes the softmax probabilities of the logits.
2461
Sets the probability of the mask_index to a small value (1e-9) to avoid numerical issues.
2462
3.Purity sampling:
2463
Computes the maximum log probabilities of the softmax distribution.
2464
Computes the indices of the top-number_to_unmask samples with the highest log probabilities.
2465
Uses these indices to sample new values from the original distribution.
2466
4. Unmasking and updating:
2467
Creates a mask to select the top-number_to_unmask samples.
2468
Uses this mask to update the current state xt with the new samples.
2469
5. Re-masking:
2470
Generates a new mask to randomly re-mask some of the updated samples.
2471
Applies this mask to the updated state xt.
2472
2473
**Arguments**:
2474
2475
- `logits` _Tensor_ - The input logits.
2476
- `t` _Tensor_ - The current time step.
2477
- `xt` _Tensor_ - The current state.
2478
- `dt` _Tensor_ - The time step increment.
2479
- `temperature` _Float, optional_ - The temperature for the softmax calculation. Defaults to 1.0.
2480
- `stochasticity` _Float, optional_ - The stochasticity value for the step calculation. Defaults to 1.0.
2481
2482
2483
**Returns**:
2484
2485
- `Tensor` - The updated state.
2486
2487
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherstep_argmax"></a>
2488
2489
#### step\_argmax
2490
2491
```python
2492
def step_argmax(model_out: Tensor)
2493
```
2494
2495
Returns the index of the maximum value in the last dimension of the model output.
2496
2497
**Arguments**:
2498
2499
- `model_out` _Tensor_ - The output of the model.
2500
2501
<a id="mocointerpolantscontinuous_timediscretediscrete_flow_matchingDiscreteFlowMatcherstep_simple_sample"></a>
2502
2503
#### step\_simple\_sample
2504
2505
```python
2506
def step_simple_sample(model_out: Tensor,
2507
                       temperature: float = 1.0,
2508
                       num_samples: int = 1)
2509
```
2510
2511
Samples from the model output logits. Leads to more diversity than step_argmax.
2512
2513
**Arguments**:
2514
2515
- `model_out` _Tensor_ - The output of the model.
2516
- `temperature` _Float, optional_ - The temperature for the softmax calculation. Defaults to 1.0.
2517
- `num_samples` _int_ - Number of samples to return
2518
2519
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_sampler"></a>
2520
2521
# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.ot\_sampler
2522
2523
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSampler"></a>
2524
2525
## OTSampler Objects
2526
2527
```python
2528
class OTSampler()
2529
```
2530
2531
Sampler for Exact Mini-batch Optimal Transport Plan.
2532
2533
OTSampler implements sampling coordinates according to an OT plan (wrt squared Euclidean cost)
2534
with different implementations of the plan calculation. Code is adapted from https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py
2535
2536
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSampler__init__"></a>
2537
2538
#### \_\_init\_\_
2539
2540
```python
2541
def __init__(method: str = "exact",
2542
             device: Union[str, torch.device] = "cpu",
2543
             num_threads: int = 1) -> None
2544
```
2545
2546
Initialize the OTSampler class.
2547
2548
**Arguments**:
2549
2550
- `method` _str_ - Choose which optimal transport solver you would like to use. Currently only support exact OT solvers (pot.emd).
2551
- `device` _Union[str, torch.device], optional_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
2552
- `num_threads` _Union[int, str], optional_ - Number of threads to use for OT solver. If "max", uses the maximum number of threads. Default is 1.
2553
2554
2555
**Raises**:
2556
2557
- `ValueError` - If the OT solver is not documented.
2558
- `NotImplementedError` - If the OT solver is not implemented.
2559
2560
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSamplerto_device"></a>
2561
2562
#### to\_device
2563
2564
```python
2565
def to_device(device: str)
2566
```
2567
2568
Moves all internal tensors to the specified device and updates the `self.device` attribute.
2569
2570
**Arguments**:
2571
2572
- `device` _str_ - The device to move the tensors to (e.g. "cpu", "cuda:0").
2573
2574
2575
**Notes**:
2576
2577
  This method is used to transfer the internal state of the OTSampler to a different device.
2578
  It updates the `self.device` attribute to reflect the new device and moves all internal tensors to the specified device.
2579
2580
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSamplersample_map"></a>
2581
2582
#### sample\_map
2583
2584
```python
2585
def sample_map(pi: Tensor,
2586
               batch_size: int,
2587
               replace: Bool = False) -> Tuple[Tensor, Tensor]
2588
```
2589
2590
Draw source and target samples from pi $(x,z) \sim \pi$.
2591
2592
**Arguments**:
2593
2594
- `pi` _Tensor_ - shape (bs, bs), the OT matrix between noise and data in minibatch.
2595
- `batch_size` _int_ - The batch size of the minibatch.
2596
- `replace` _bool_ - sampling w/ or w/o replacement from the OT plan, default to False.
2597
2598
2599
**Returns**:
2600
2601
- `Tuple` - tuple of 2 tensors, represents the indices of noise and data samples from pi.
2602
2603
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSamplerget_ot_matrix"></a>
2604
2605
#### get\_ot\_matrix
2606
2607
```python
2608
def get_ot_matrix(x0: Tensor,
2609
                  x1: Tensor,
2610
                  mask: Optional[Tensor] = None) -> Tensor
2611
```
2612
2613
Compute the OT matrix between a source and a target minibatch.
2614
2615
**Arguments**:
2616
2617
- `x0` _Tensor_ - shape (bs, *dim), noise from source minibatch.
2618
- `x1` _Tensor_ - shape (bs, *dim), data from source minibatch.
2619
- `mask` _Optional[Tensor], optional_ - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
2620
2621
2622
**Returns**:
2623
2624
- `p` _Tensor_ - shape (bs, bs), the OT matrix between noise and data in minibatch.
2625
2626
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationot_samplerOTSamplerapply_augmentation"></a>
2627
2628
#### apply\_augmentation
2629
2630
```python
2631
def apply_augmentation(
2632
    x0: Tensor,
2633
    x1: Tensor,
2634
    mask: Optional[Tensor] = None,
2635
    replace: Bool = False,
2636
    sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0"
2637
) -> Tuple[Tensor, Tensor, Optional[Tensor]]
2638
```
2639
2640
Sample indices for noise and data in minibatch according to OT plan.
2641
2642
Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
2643
minibatch and draw source and target samples from pi $(x,z) \sim \pi$.
2644
2645
**Arguments**:
2646
2647
- `x0` _Tensor_ - shape (bs, *dim), noise from source minibatch.
2648
- `x1` _Tensor_ - shape (bs, *dim), data from source minibatch.
2649
- `mask` _Optional[Tensor], optional_ - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
2650
- `replace` _bool_ - sampling w/ or w/o replacement from the OT plan, default to False.
2651
- `sort` _str_ - Optional Literal string to sort either x1 or x0 based on the input.
2652
2653
2654
**Returns**:
2655
2656
- `Tuple` - tuple of 2 tensors or 3 tensors if mask is used, represents the noise (plus mask) and data samples following OT plan pi.
2657
2658
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_sampler"></a>
2659
2660
# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.equivariant\_ot\_sampler
2661
2662
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSampler"></a>
2663
2664
## EquivariantOTSampler Objects
2665
2666
```python
2667
class EquivariantOTSampler()
2668
```
2669
2670
Sampler for Mini-batch Optimal Transport Plan with cost calculated after Kabsch alignment.
2671
2672
EquivariantOTSampler implements sampling coordinates according to an OT plan
2673
(wrt squared Euclidean cost after Kabsch alignment) with different implementations of the plan calculation.
2674
2675
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSampler__init__"></a>
2676
2677
#### \_\_init\_\_
2678
2679
```python
2680
def __init__(method: str = "exact",
2681
             device: Union[str, torch.device] = "cpu",
2682
             num_threads: int = 1) -> None
2683
```
2684
2685
Initialize the OTSampler class.
2686
2687
**Arguments**:
2688
2689
- `method` _str_ - Choose which optimal transport solver you would like to use. Currently only support exact OT solvers (pot.emd).
2690
- `device` _Union[str, torch.device], optional_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
2691
- `num_threads` _Union[int, str], optional_ - Number of threads to use for OT solver. If "max", uses the maximum number of threads. Default is 1.
2692
2693
2694
**Raises**:
2695
2696
- `ValueError` - If the OT solver is not documented.
2697
- `NotImplementedError` - If the OT solver is not implemented.
2698
2699
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplerto_device"></a>
2700
2701
#### to\_device
2702
2703
```python
2704
def to_device(device: str)
2705
```
2706
2707
Moves all internal tensors to the specified device and updates the `self.device` attribute.
2708
2709
**Arguments**:
2710
2711
- `device` _str_ - The device to move the tensors to (e.g. "cpu", "cuda:0").
2712
2713
2714
**Notes**:
2715
2716
  This method is used to transfer the internal state of the OTSampler to a different device.
2717
  It updates the `self.device` attribute to reflect the new device and moves all internal tensors to the specified device.
2718
2719
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplersample_map"></a>
2720
2721
#### sample\_map
2722
2723
```python
2724
def sample_map(pi: Tensor,
2725
               batch_size: int,
2726
               replace: Bool = False) -> Tuple[Tensor, Tensor]
2727
```
2728
2729
Draw source and target samples from pi $(x,z) \sim \pi$.
2730
2731
**Arguments**:
2732
2733
- `pi` _Tensor_ - shape (bs, bs), the OT matrix between noise and data in minibatch.
2734
- `batch_size` _int_ - The batch size of the minibatch.
2735
- `replace` _bool_ - sampling w/ or w/o replacement from the OT plan, default to False.
2736
2737
2738
**Returns**:
2739
2740
- `Tuple` - tuple of 2 tensors, represents the indices of noise and data samples from pi.
2741
2742
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplerkabsch_align"></a>
2743
2744
#### kabsch\_align
2745
2746
```python
2747
def kabsch_align(target: Tensor, noise: Tensor) -> Tensor
2748
```
2749
2750
Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.
2751
2752
**Arguments**:
2753
2754
- `target` _Tensor_ - shape (N, *dim), data from source minibatch.
2755
- `noise` _Tensor_ - shape (N, *dim), noise from source minibatch.
2756
2757
2758
**Returns**:
2759
2760
- `R` _Tensor_ - shape (*dim, *dim), the rotation matrix.
2761
2762
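For reference, a textbook SVD-based Kabsch solver matching the contract above; this is an illustrative sketch, not the library's implementation, and it assumes both point clouds are already centered:

```python
import torch

def kabsch_rotation(target: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Return R such that target @ R.T is least-squares closest to noise.

    Illustrative only; assumes centered inputs of shape (N, dim).
    """
    H = target.T @ noise                          # (dim, dim) cross-covariance
    U, S, Vt = torch.linalg.svd(H)
    d = torch.sign(torch.linalg.det(Vt.T @ U.T))  # correct for improper rotations (reflections)
    D = torch.eye(target.shape[-1], dtype=target.dtype)
    D[-1, -1] = d
    return Vt.T @ D @ U.T

# Usage: center both clouds, then align.
# R = kabsch_rotation(target - target.mean(0), noise - noise.mean(0))
```
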
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplerget_ot_matrix"></a>
2763
2764
#### get\_ot\_matrix
2765
2766
```python
2767
def get_ot_matrix(x0: Tensor,
2768
                  x1: Tensor,
2769
                  mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]
2770
```
2771
2772
Compute the OT matrix between a source and a target minibatch.
2773
2774
**Arguments**:
2775
2776
- `x0` _Tensor_ - shape (bs, *dim), noise from source minibatch.
2777
- `x1` _Tensor_ - shape (bs, *dim), data from source minibatch.
2778
- `mask` _Optional[Tensor], optional_ - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
2779
2780
2781
**Returns**:
2782
2783
- `p` _Tensor_ - shape (bs, bs), the OT matrix between noise and data in minibatch.
2784
- `Rs` _Tensor_ - shape (bs, bs, *dim, *dim), the rotation matrix between noise and data in minibatch.
2785
2786
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationequivariant_ot_samplerEquivariantOTSamplerapply_augmentation"></a>
2787
2788
#### apply\_augmentation
2789
2790
```python
2791
def apply_augmentation(
2792
    x0: Tensor,
2793
    x1: Tensor,
2794
    mask: Optional[Tensor] = None,
2795
    replace: Bool = False,
2796
    sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0"
2797
) -> Tuple[Tensor, Tensor, Optional[Tensor]]
2798
```
2799
2800
Sample indices for noise and data in minibatch according to OT plan.
2801
2802
Compute the OT plan $\pi$ (wrt squared Euclidean cost after Kabsch alignment) between a source and a target
2803
minibatch and draw source and target samples from pi $(x,z) \sim \pi$.
2804
2805
**Arguments**:
2806
2807
- `x0` _Tensor_ - shape (bs, *dim), noise from source minibatch.
2808
- `x1` _Tensor_ - shape (bs, *dim), data from source minibatch.
2809
- `mask` _Optional[Tensor], optional_ - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
2810
- `replace` _bool_ - sampling w/ or w/o replacement from the OT plan, default to False.
2811
- `sort` _str_ - Optional Literal string to sort either x1 or x0 based on the input.
2812
2813
2814
**Returns**:
2815
2816
- `Tuple` - tuple of 2 tensors, represents the noise and data samples following OT plan pi.
2817
2818
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentation"></a>
2819
2820
# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.kabsch\_augmentation
2821
2822
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentation"></a>
2823
2824
## KabschAugmentation Objects
2825
2826
```python
2827
class KabschAugmentation()
2828
```
2829
2830
Point-wise Kabsch alignment.
2831
2832
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentation__init__"></a>
2833
2834
#### \_\_init\_\_
2835
2836
```python
2837
def __init__()
2838
```
2839
2840
Initialize the KabschAugmentation instance.
2841
2842
**Notes**:
2843
2844
  - This implementation assumes no required initialization arguments.
2845
  - You can add instance variables (e.g., `self.variable_name`) as needed.
2846
2847
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentationkabsch_align"></a>
2848
2849
#### kabsch\_align
2850
2851
```python
2852
def kabsch_align(target: Tensor, noise: Tensor)
2853
```
2854
2855
Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.
2856
2857
**Arguments**:
2858
2859
- `target` _Tensor_ - shape (N, *dim), data from source minibatch.
2860
- `noise` _Tensor_ - shape (N, *dim), noise from source minibatch.
2861
2862
2863
**Returns**:
2864
2865
- `R` _Tensor_ - shape (*dim, *dim), the rotation matrix.
2866
  Aliged Target (Tensor): target tensor rotated and shifted to reduced RMSD with noise
2867
2868
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentationbatch_kabsch_align"></a>
2869
2870
#### batch\_kabsch\_align
2871
2872
```python
2873
def batch_kabsch_align(target: Tensor, noise: Tensor)
2874
```
2875
2876
Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.
2877
2878
**Arguments**:
2879
2880
- `target` _Tensor_ - shape (B, N, *dim), data from source minibatch.
2881
- `noise` _Tensor_ - shape (B, N, *dim), noise from source minibatch.
2882
2883
2884
**Returns**:
2885
2886
- `R` _Tensor_ - shape (*dim, *dim), the rotation matrix.
2887
  Aliged Target (Tensor): target tensor rotated and shifted to reduced RMSD with noise
2888
2889
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationkabsch_augmentationKabschAugmentationapply_augmentation"></a>
2890
2891
#### apply\_augmentation
2892
2893
```python
2894
def apply_augmentation(x0: Tensor,
2895
                       x1: Tensor,
2896
                       mask: Optional[Tensor] = None,
2897
                       align_noise_to_data=True) -> Tuple[Tensor, Tensor]
2898
```
2899
2900
Sample indices for noise and data in minibatch according to OT plan.
2901
2902
Compute the OT plan $\pi$ (wrt squared Euclidean cost after Kabsch alignment) between a source and a target
2903
minibatch and draw source and target samples from pi $(x,z) \sim \pi$.
2904
2905
**Arguments**:
2906
2907
- `x0` _Tensor_ - shape (bs, *dim), noise from source minibatch.
2908
- `x1` _Tensor_ - shape (bs, *dim), data from source minibatch.
2909
- `mask` _Optional[Tensor], optional_ - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
2910
- `replace` _bool_ - sampling w/ or w/o replacement from the OT plan, default to False.
2911
- `align_noise_to_data` _bool_ - Direction of alignment default is True meaning it augments Noise to reduce error to Data.
2912
2913
2914
**Returns**:
2915
2916
- `Tuple` - tuple of 2 tensors, represents the noise and data samples following OT plan pi.
2917
2918
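A minimal usage sketch, with illustrative shapes:

```python
import torch
from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.kabsch_augmentation import KabschAugmentation

kabsch = KabschAugmentation()
x1 = torch.randn(64, 32, 3)  # data; (bs, nodes, 3) is illustrative
x0 = torch.randn(64, 32, 3)  # noise
x0_aligned, x1_out = kabsch.apply_augmentation(x0, x1, align_noise_to_data=True)
# Each noise point cloud is rotated to minimize RMSD to its paired data point cloud.
```
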
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentation"></a>
2919
2920
# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation
2921
2922
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationaugmentation_types"></a>
2923
2924
# bionemo.moco.interpolants.continuous\_time.continuous.data\_augmentation.augmentation\_types
2925
2926
<a id="mocointerpolantscontinuous_timecontinuousdata_augmentationaugmentation_typesAugmentationType"></a>
2927
2928
## AugmentationType Objects
2929
2930
```python
2931
class AugmentationType(Enum)
2932
```
2933
2934
An enumeration representing the type ofOptimal Transport that can be used in Continuous Flow Matching.
2935
2936
- **EXACT_OT**: Standard mini batch optimal transport defined in  https://arxiv.org/pdf/2302.00482.
2937
- **EQUIVARIANT_OT**: Adding roto/translation optimization to mini batch OT see https://arxiv.org/pdf/2306.15030  https://arxiv.org/pdf/2312.07168 4.2.
2938
- **KABSCH**: Simple Kabsch alignment between each data and noise point, No permuation # https://arxiv.org/pdf/2410.22388 Sec 3.2
2939
2940
These prediction types can be used to train neural networks for specific tasks, such as denoising, image synthesis, or time-series forecasting.
2941
2942
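As a rough illustration, the enum can drive a choice between the samplers documented above; the `make_augmentation` helper below is hypothetical, not the library's dispatch logic:

```python
from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.augmentation_types import AugmentationType
from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.ot_sampler import OTSampler
from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.equivariant_ot_sampler import EquivariantOTSampler
from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.kabsch_augmentation import KabschAugmentation

def make_augmentation(kind: AugmentationType):
    # Hypothetical helper mapping each documented enum member to its sampler.
    if kind == AugmentationType.EXACT_OT:
        return OTSampler(method="exact")
    if kind == AugmentationType.EQUIVARIANT_OT:
        return EquivariantOTSampler(method="exact")
    return KabschAugmentation()
```
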
<a id="mocointerpolantscontinuous_timecontinuous"></a>
2943
2944
# bionemo.moco.interpolants.continuous\_time.continuous
2945
2946
<a id="mocointerpolantscontinuous_timecontinuousvdm"></a>
2947
2948
# bionemo.moco.interpolants.continuous\_time.continuous.vdm
2949
2950
<a id="mocointerpolantscontinuous_timecontinuousvdmVDM"></a>
2951
2952
## VDM Objects
2953
2954
```python
2955
class VDM(Interpolant)
2956
```
2957
2958
A Variational Diffusion Models (VDM) interpolant.
2959
2960
-------
2961
2962
**Examples**:
2963
2964
```python
2965
>>> import torch
2966
>>> from bionemo.bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
2967
>>> from bionemo.bionemo.moco.distributions.time.uniform import UniformTimeDistribution
2968
>>> from bionemo.bionemo.moco.interpolants.discrete_time.continuous.vdm import VDM
2969
>>> from bionemo.bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform
2970
>>> from bionemo.bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
2971
2972
2973
vdm = VDM(
2974
    time_distribution = UniformTimeDistribution(...),
2975
    prior_distribution = GaussianPrior(...),
2976
    noise_schedule = CosineSNRTransform(...),
2977
    )
2978
model = Model(...)
2979
2980
# Training
2981
for epoch in range(1000):
2982
    data = data_loader.get(...)
2983
    time = vdm.sample_time(batch_size)
2984
    noise = vdm.sample_prior(data.shape)
2985
    xt = vdm.interpolate(data, noise, time)
2986
2987
    x_pred = model(xt, time)
2988
    loss = vdm.loss(x_pred, data, time)
2989
    loss.backward()
2990
2991
# Generation
2992
x_pred = vdm.sample_prior(data.shape)
2993
for t in LinearInferenceSchedule(...).generate_schedule():
2994
    time = torch.full((batch_size,), t)
2995
    x_hat = model(x_pred, time)
2996
    x_pred = vdm.step(x_hat, time, x_pred)
2997
return x_pred
2998
2999
```
3000
3001
<a id="mocointerpolantscontinuous_timecontinuousvdmVDM__init__"></a>
3002
3003
#### \_\_init\_\_
3004
3005
```python
3006
def __init__(time_distribution: TimeDistribution,
3007
             prior_distribution: PriorDistribution,
3008
             noise_schedule: ContinuousSNRTransform,
3009
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
3010
             device: Union[str, torch.device] = "cpu",
3011
             rng_generator: Optional[torch.Generator] = None)
3012
```
3013
3014
Initializes the DDPM interpolant.
3015
3016
**Arguments**:
3017
3018
- `time_distribution` _TimeDistribution_ - The distribution of time steps, used to sample time points for the diffusion process.
3019
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable, used as the starting point for the diffusion process.
3020
- `noise_schedule` _ContinuousSNRTransform_ - The schedule of noise, defining the amount of noise added at each time step.
3021
- `prediction_type` _PredictionType, optional_ - The type of prediction, either "data" or another type. Defaults to "data".
3022
- `device` _str, optional_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
3023
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
3024
3025
<a id="mocointerpolantscontinuous_timecontinuousvdmVDMinterpolate"></a>
3026
3027
#### interpolate
3028
3029
```python
3030
def interpolate(data: Tensor, t: Tensor, noise: Tensor)
3031
```
3032
3033
Get x(t) with given time t from noise and data.
3034
3035
**Arguments**:
3036
3037
- `data` _Tensor_ - target
3038
- `t` _Tensor_ - time
3039
- `noise` _Tensor_ - noise from prior()
3040
3041
<a id="mocointerpolantscontinuous_timecontinuousvdmVDMforward_process"></a>
3042
3043
#### forward\_process
3044
3045
```python
3046
def forward_process(data: Tensor, t: Tensor, noise: Optional[Tensor] = None)
3047
```
3048
3049
Get x(t) with given time t from noise and data.
3050
3051
**Arguments**:
3052
3053
- `data` _Tensor_ - target
3054
- `t` _Tensor_ - time
3055
- `noise` _Tensor, optional_ - noise from prior(). Defaults to None
3056
3057
<a id="mocointerpolantscontinuous_timecontinuousvdmVDMprocess_data_prediction"></a>
3058
3059
#### process\_data\_prediction
3060
3061
```python
3062
def process_data_prediction(model_output: Tensor, sample, t)
3063
```
3064
3065
Converts the model output to a data prediction based on the prediction type.
3066
3067
This conversion stems from the Progressive Distillation for Fast Sampling of Diffusion Models https://arxiv.org/pdf/2202.00512.
3068
Given the model output and the sample, we convert the output to a data prediction based on the prediction type.
3069
The conversion formulas are as follows:
3070
- For "noise" prediction type: `pred_data = (sample - noise_scale * model_output) / data_scale`
3071
- For "data" prediction type: `pred_data = model_output`
3072
- For "v_prediction" prediction type: `pred_data = data_scale * sample - noise_scale * model_output`
3073
3074
**Arguments**:
3075
3076
- `model_output` _Tensor_ - The output of the model.
3077
- `sample` _Tensor_ - The input sample.
3078
- `t` _Tensor_ - The time step.
3079
3080
3081
**Returns**:
3082
3083
  The data prediction based on the prediction type.
3084
3085
3086
**Raises**:
3087
3088
- `ValueError` - If the prediction type is not one of "noise", "data", or "v_prediction".
3089
3090
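The three conversion rules above can be summarized in a small sketch; here `data_scale` and `noise_scale` stand in for the schedule-derived coefficients at time `t`, and the helper itself is illustrative rather than the library's code:

```python
def to_data_prediction(model_output, sample, data_scale, noise_scale, prediction_type):
    # Mirrors the conversion table documented above.
    if prediction_type == "noise":
        return (sample - noise_scale * model_output) / data_scale
    if prediction_type == "data":
        return model_output
    if prediction_type == "v_prediction":
        return data_scale * sample - noise_scale * model_output
    raise ValueError(f"Unknown prediction type: {prediction_type}")
```
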
<a id="mocointerpolantscontinuous_timecontinuousvdmVDMprocess_noise_prediction"></a>
3091
3092
#### process\_noise\_prediction
3093
3094
```python
3095
def process_noise_prediction(model_output: Tensor, sample: Tensor, t: Tensor)
3096
```
3097
3098
Do the same as process_data_prediction but take the model output and convert to nosie.
3099
3100
**Arguments**:
3101
3102
- `model_output` _Tensor_ - The output of the model.
3103
- `sample` _Tensor_ - The input sample.
3104
- `t` _Tensor_ - The time step.
3105
3106
3107
**Returns**:
3108
3109
  The input as noise if the prediction type is "noise".
3110
3111
3112
**Raises**:
3113
3114
- `ValueError` - If the prediction type is not "noise".
3115
3116
<a id="mocointerpolantscontinuous_timecontinuousvdmVDMstep"></a>
3117
3118
#### step
3119
3120
```python
3121
def step(model_out: Tensor,
3122
         t: Tensor,
3123
         xt: Tensor,
3124
         dt: Tensor,
3125
         mask: Optional[Tensor] = None,
3126
         center: Bool = False,
3127
         temperature: Float = 1.0)
3128
```
3129
3130
Do one step integration.
3131
3132
**Arguments**:
3133
3134
- `model_out` _Tensor_ - The output of the model.
3135
- `xt` _Tensor_ - The current data point.
3136
- `t` _Tensor_ - The current time step.
3137
- `dt` _Tensor_ - The time step increment.
3138
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
3139
- `center` _bool_ - Whether to center the data. Defaults to False.
3140
- `temperature` _Float_ - The temperature parameter for low temperature sampling. Defaults to 1.0.
3141
3142
3143
**Notes**:
3144
3145
  The temperature parameter controls the trade off between diversity and sample quality.
3146
  Decreasing the temperature sharpens the sampling distribtion to focus on more likely samples.
3147
  The impact of low temperature sampling must be ablated analytically.
3148
3149
<a id="mocointerpolantscontinuous_timecontinuousvdmVDMscore"></a>
3150
3151
#### score
3152
3153
```python
3154
def score(x_hat: Tensor, xt: Tensor, t: Tensor)
3155
```
3156
3157
Converts the data prediction to the estimated score function.
3158
3159
**Arguments**:
3160
3161
- `x_hat` _tensor_ - The predicted data point.
3162
- `xt` _Tensor_ - The current data point.
3163
- `t` _Tensor_ - The time step.
3164
3165
3166
**Returns**:
3167
3168
  The estimated score function.
3169
3170
<a id="mocointerpolantscontinuous_timecontinuousvdmVDMstep_ddim"></a>
3171
3172
#### step\_ddim
3173
3174
```python
3175
def step_ddim(model_out: Tensor,
3176
              t: Tensor,
3177
              xt: Tensor,
3178
              dt: Tensor,
3179
              mask: Optional[Tensor] = None,
3180
              eta: Float = 0.0,
3181
              center: Bool = False)
3182
```
3183
3184
Do one step of DDIM sampling.
3185
3186
From the ddpm equations alpha_bar = alpha**2 and  1 - alpha**2 = sigma**2
3187
3188
**Arguments**:
3189
3190
- `model_out` _Tensor_ - output of the model
3191
- `t` _Tensor_ - current time step
3192
- `xt` _Tensor_ - current data point
3193
- `dt` _Tensor_ - The time step increment.
3194
- `mask` _Optional[Tensor], optional_ - mask for the data point. Defaults to None.
3195
- `eta` _Float, optional_ - DDIM sampling parameter. Defaults to 0.0.
3196
- `center` _Bool, optional_ - whether to center the data point. Defaults to False.
3197
3198

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMset_loss_weight_fn"></a>

#### set\_loss\_weight\_fn

```python
def set_loss_weight_fn(fn: Callable)
```

Sets the loss_weight attribute of the instance to the given function.

**Arguments**:

- `fn` - The function to set as the loss_weight attribute. This function should take three arguments: raw_loss, t, and weight_type.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMloss_weight"></a>

#### loss\_weight

```python
def loss_weight(raw_loss: Tensor,
                t: Tensor,
                weight_type: str,
                dt: Float = 0.001) -> Tensor
```

Calculates the weight for the loss based on the given weight type.

This function computes the loss weight according to the specified `weight_type`.
The available weight types are:
- "ones": uniform weight of 1.0
- "data_to_noise": derived from Equation (9) of https://arxiv.org/pdf/2202.00512
- "variational_objective_discrete": based on the variational objective, see https://arxiv.org/pdf/2202.00512

**Arguments**:

- `raw_loss` _Tensor_ - The raw loss calculated from the model prediction and target.
- `t` _Tensor_ - The time step.
- `weight_type` _str_ - The type of weight to use. Can be "ones", "data_to_noise", or "variational_objective_discrete".
- `dt` _Float, optional_ - The time step increment. Defaults to 0.001.


**Returns**:

- `Tensor` - The weight for the loss.


**Raises**:

- `ValueError` - If the weight type is not recognized.
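
To illustrate the contract expected by `set_loss_weight_fn`, here is a minimal sketch of a custom weight function; the stand-in schedule alpha_bar(t) = 1 - t and the SNR-style weight are assumptions for the example, not the class's built-in "data_to_noise" weight:

```python
from torch import Tensor


def snr_weight(raw_loss: Tensor, t: Tensor, weight_type: str) -> Tensor:
    if weight_type == "ones":
        return raw_loss
    # SNR-style reweighting with a stand-in alpha_bar(t) = 1 - t (assumption).
    alpha_bar = (1.0 - t).clamp(min=1e-5, max=1.0 - 1e-5)
    snr = alpha_bar / (1.0 - alpha_bar)
    return raw_loss * snr.view(-1, *([1] * (raw_loss.dim() - 1)))


# vdm.set_loss_weight_fn(snr_weight)  # attach to an existing VDM instance
```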

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMloss"></a>

#### loss

```python
def loss(model_pred: Tensor,
         target: Tensor,
         t: Tensor,
         dt: Optional[Float] = 0.001,
         mask: Optional[Tensor] = None,
         weight_type: str = "ones")
```

Calculates the loss given the model prediction, target, and time.

**Arguments**:

- `model_pred` _Tensor_ - The predicted output from the model.
- `target` _Tensor_ - The target output for the model prediction.
- `t` _Tensor_ - The time at which the loss is calculated.
- `dt` _Optional[Float], optional_ - The time step increment. Defaults to 0.001.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `weight_type` _str, optional_ - The type of weight to use for the loss. Can be "ones", "data_to_noise", or "variational_objective". Defaults to "ones".


**Returns**:

- `Tensor` - The calculated loss batch tensor.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMstep_hybrid_sde"></a>

#### step\_hybrid\_sde

```python
def step_hybrid_sde(model_out: Tensor,
                    t: Tensor,
                    xt: Tensor,
                    dt: Tensor,
                    mask: Optional[Tensor] = None,
                    center: Bool = False,
                    temperature: Float = 1.0,
                    equilibrium_rate: Float = 0.0) -> Tensor
```

Do one step integration of Hybrid Langevin-Reverse Time SDE.

See section B.3 page 37 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf
and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `xt` _Tensor_ - The current data point.
- `t` _Tensor_ - The current time step.
- `dt` _Tensor_ - The time step increment.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low temperature sampling. Defaults to 1.0.
- `equilibrium_rate` _Float, optional_ - The rate of Langevin equilibration. Scales the amount of Langevin dynamics per unit time. Best values are in the range [1.0, 5.0]. Defaults to 0.0.


**Notes**:

  For all step functions that use the SDE formulation, it is important to note that we are moving backwards in time, which corresponds to an apparent sign change.
  A clear example can be seen in slide 29 of https://ernestryu.com/courses/FM/diffusion1.pdf.

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMstep_ode"></a>

#### step\_ode

```python
def step_ode(model_out: Tensor,
             t: Tensor,
             xt: Tensor,
             dt: Tensor,
             mask: Optional[Tensor] = None,
             center: Bool = False,
             temperature: Float = 1.0) -> Tensor
```

Do one step integration of ODE.

See section B page 36 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf
and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `xt` _Tensor_ - The current data point.
- `t` _Tensor_ - The current time step.
- `dt` _Tensor_ - The time step increment.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low temperature sampling. Defaults to 1.0.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matching"></a>

# bionemo.moco.interpolants.continuous\_time.continuous.continuous\_flow\_matching

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcher"></a>

## ContinuousFlowMatcher Objects

```python
class ContinuousFlowMatcher(Interpolant)
```

A Continuous Flow Matching interpolant.

-------

**Examples**:

```python
>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching import ContinuousFlowMatcher
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

flow_matcher = ContinuousFlowMatcher(
    time_distribution = UniformTimeDistribution(...),
    prior_distribution = GaussianPrior(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = flow_matcher.sample_time(batch_size)
    noise = flow_matcher.sample_prior(data.shape)
    noise, data = flow_matcher.apply_augmentation(noise, data) # Optional, only for OT
    xt = flow_matcher.interpolate(data, time, noise)
    flow = flow_matcher.calculate_target(data, noise)

    u_pred = model(xt, time)
    loss = flow_matcher.loss(u_pred, flow).mean()  # loss() returns a batch tensor
    loss.backward()

# Generation
x_pred = flow_matcher.sample_prior(data.shape)
inference_sched = LinearInferenceSchedule(...)
for t in inference_sched.generate_schedule():
    time = inference_sched.pad_time(x_pred.shape[0], t)
    u_hat = model(x_pred, time)
    x_pred = flow_matcher.step(u_hat, x_pred, time)
return x_pred

```

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcher__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
             sigma: Float = 0,
             augmentation_type: Optional[Union[AugmentationType, str]] = None,
             augmentation_num_threads: int = 1,
             data_scale: Float = 1.0,
             device: Union[str, torch.device] = "cpu",
             rng_generator: Optional[torch.Generator] = None,
             eps: Float = 1e-5)
```

Initializes the Continuous Flow Matching interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps, used to sample time points for the diffusion process.
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable, used as the starting point for the diffusion process.
- `prediction_type` _PredictionType, optional_ - The type of prediction, either "flow" or another type. Defaults to PredictionType.DATA.
- `sigma` _Float, optional_ - The standard deviation of the Gaussian noise added to the interpolated data. Defaults to 0.
- `augmentation_type` _Optional[Union[AugmentationType, str]], optional_ - The type of optimal transport, if applicable. Defaults to None.
- `augmentation_num_threads` - Number of threads to use for the OT solver. If "max", uses the maximum number of threads. Default is 1.
- `data_scale` _Float, optional_ - The scale factor for the data. Defaults to 1.0.
- `device` _Union[str, torch.device], optional_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
- `eps` - Small float to prevent divide by zero.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherapply_augmentation"></a>

#### apply\_augmentation

```python
def apply_augmentation(x0: Tensor,
                       x1: Tensor,
                       mask: Optional[Tensor] = None,
                       **kwargs) -> tuple
```

Sample and apply the optimal transport plan between batched (and masked) x0 and x1.

**Arguments**:

- `x0` _Tensor_ - shape (bs, *dim), noise from source minibatch.
- `x1` _Tensor_ - shape (bs, *dim), data from source minibatch.
- `mask` _Optional[Tensor], optional_ - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
- `**kwargs` - Additional keyword arguments to be passed to self.augmentation_sampler.apply_augmentation or handled within this method.


**Returns**:

- `Tuple` - tuple of 2 tensors, representing the noise and data samples following the OT plan pi.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherundo_scale_data"></a>

#### undo\_scale\_data

```python
def undo_scale_data(data: Tensor) -> Tensor
```

Downscale the input data by the data scale factor.

**Arguments**:

- `data` _Tensor_ - The input data to downscale.


**Returns**:

  The downscaled data.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherscale_data"></a>

#### scale\_data

```python
def scale_data(data: Tensor) -> Tensor
```

Upscale the input data by the data scale factor.

**Arguments**:

- `data` _Tensor_ - The input data to upscale.


**Returns**:

  The upscaled data.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherinterpolate"></a>

#### interpolate

```python
def interpolate(data: Tensor, t: Tensor, noise: Tensor) -> Tensor
```

Get x_t with given time t from noise (x_0) and data (x_1).

Currently, we use the linear interpolation as defined in:
1. Rectified flow: https://arxiv.org/abs/2209.03003.
2. Conditional flow matching: https://arxiv.org/abs/2210.02747 (called conditional optimal transport).

**Arguments**:

- `noise` _Tensor_ - noise from prior(), shape (batchsize, nodes, features)
- `t` _Tensor_ - time, shape (batchsize)
- `data` _Tensor_ - target, shape (batchsize, nodes, features)
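
A minimal sketch of the linear (conditional-OT) path described above, assuming sigma = 0 and no data scaling; the class method additionally supports a sigma-noised path:

```python
from torch import Tensor


def linear_interpolate(data: Tensor, t: Tensor, noise: Tensor) -> Tensor:
    # x_t = (1 - t) * x_0 + t * x_1, with t broadcast from (batchsize,) over node/feature dims.
    t = t.view(-1, *([1] * (data.dim() - 1)))
    return (1.0 - t) * noise + t * data
```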

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatchercalculate_target"></a>

#### calculate\_target

```python
def calculate_target(data: Tensor,
                     noise: Tensor,
                     mask: Optional[Tensor] = None) -> Tensor
```

Get the target vector field at time t.

**Arguments**:

- `noise` _Tensor_ - noise from prior(), shape (batchsize, nodes, features)
- `data` _Tensor_ - target, shape (batchsize, nodes, features)
- `mask` _Optional[Tensor], optional_ - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.


**Returns**:

- `Tensor` - The target vector field at time t.
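
For the linear path above, the target vector field is constant in time; a sketch under that assumption:

```python
from typing import Optional

from torch import Tensor


def flow_target(data: Tensor, noise: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    # u_t = x_1 - x_0 for x_t = (1 - t) * x_0 + t * x_1.
    flow = data - noise
    if mask is not None:
        flow = flow * mask.unsqueeze(-1)  # zero out padded nodes
    return flow
```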

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherprocess_vector_field_prediction"></a>

#### process\_vector\_field\_prediction

```python
def process_vector_field_prediction(model_output: Tensor,
                                    xt: Optional[Tensor] = None,
                                    t: Optional[Tensor] = None,
                                    mask: Optional[Tensor] = None)
```

Process the model output based on the prediction type to calculate the vector field.

**Arguments**:

- `model_output` _Tensor_ - The output of the model.
- `xt` _Tensor_ - The input sample.
- `t` _Tensor_ - The time step.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the model output. Defaults to None.


**Returns**:

  The vector field prediction based on the prediction type.


**Raises**:

- `ValueError` - If the prediction type is not "flow" or "data".

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherprocess_data_prediction"></a>

#### process\_data\_prediction

```python
def process_data_prediction(model_output: Tensor,
                            xt: Optional[Tensor] = None,
                            t: Optional[Tensor] = None,
                            mask: Optional[Tensor] = None)
```

Process the model output based on the prediction type to generate clean data.

**Arguments**:

- `model_output` _Tensor_ - The output of the model.
- `xt` _Tensor_ - The input sample.
- `t` _Tensor_ - The time step.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the model output. Defaults to None.


**Returns**:

  The data prediction based on the prediction type.


**Raises**:

- `ValueError` - If the prediction type is not "flow".

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherstep"></a>

#### step

```python
def step(model_out: Tensor,
         xt: Tensor,
         dt: Tensor,
         t: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         center: Bool = False)
```

Perform a single ODE step integration using the Euler method.

**Arguments**:

- `model_out` _Tensor_ - The output of the model at the current time step.
- `xt` _Tensor_ - The current intermediate state.
- `dt` _Tensor_ - The time step size.
- `t` _Tensor, optional_ - The current time. Defaults to None.
- `mask` _Optional[Tensor], optional_ - A mask to apply to the model output. Defaults to None.
- `center` _Bool, optional_ - Whether to center the output. Defaults to False.


**Returns**:

- `x_next` _Tensor_ - The updated state of the system after the single step, x_(t+dt).


**Notes**:

  - If a mask is provided, it is applied element-wise to the model output before scaling.
  - The `clean` method is called on the updated state before it is returned.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherstep_score_stochastic"></a>

#### step\_score\_stochastic

```python
def step_score_stochastic(model_out: Tensor,
                          xt: Tensor,
                          dt: Tensor,
                          t: Tensor,
                          mask: Optional[Tensor] = None,
                          gt_mode: str = "tan",
                          gt_p: Float = 1.0,
                          gt_clamp: Optional[Float] = None,
                          score_temperature: Float = 1.0,
                          noise_temperature: Float = 1.0,
                          t_lim_ode: Float = 0.99,
                          center: Bool = False)
```

Perform a single SDE step integration using a score-based Langevin update.

d x_t = [v(x_t, t) + g(t) * s(x_t, t) * score_temperature] dt + \sqrt{2 * g(t) * noise_temperature} dw_t.

**Arguments**:

- `model_out` _Tensor_ - The output of the model at the current time step.
- `xt` _Tensor_ - The current intermediate state.
- `dt` _Tensor_ - The time step size.
- `t` _Tensor, optional_ - The current time. Defaults to None.
- `mask` _Optional[Tensor], optional_ - A mask to apply to the model output. Defaults to None.
- `gt_mode` _str, optional_ - The mode for the gt function. Defaults to "tan".
- `gt_p` _Float, optional_ - The parameter for the gt function. Defaults to 1.0.
- `gt_clamp` _Optional[Float], optional_ - Upper limit of the gt term. Defaults to None.
- `score_temperature` _Float, optional_ - The temperature for the score part of the step. Defaults to 1.0.
- `noise_temperature` _Float, optional_ - The temperature for the stochastic part of the step. Defaults to 1.0.
- `t_lim_ode` _Float, optional_ - The time limit for the ODE step. Defaults to 0.99.
- `center` _Bool, optional_ - Whether to center the output. Defaults to False.


**Returns**:

- `x_next` _Tensor_ - The updated state of the system after the single step, x_(t+dt).


**Notes**:

  - If a mask is provided, it is applied element-wise to the model output before scaling.
  - The `clean` method is called on the updated state before it is returned.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherloss"></a>

#### loss

```python
def loss(model_pred: Tensor,
         target: Tensor,
         t: Optional[Tensor] = None,
         xt: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         target_type: Union[PredictionType, str] = PredictionType.DATA)
```

Calculate the loss given the model prediction, data sample, time, and mask.

If target_type is FLOW, loss = ||v_hat - (x1 - x0)||**2.
If target_type is DATA, loss = ||x1_hat - x1||**2 / (1 - t)**2, since the target vector field x1 - x0 = (x1 - xt) / (1 - t), where xt = t * x1 + (1 - t) * x0.
This function supports any combination of prediction_type and target_type in {DATA, FLOW}.

**Arguments**:

- `model_pred` _Tensor_ - The predicted output from the model.
- `target` _Tensor_ - The target output for the model prediction.
- `t` _Optional[Tensor], optional_ - The time for the model prediction. Defaults to None.
- `xt` _Optional[Tensor], optional_ - The interpolated data. Defaults to None.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `target_type` _PredictionType, optional_ - The type of the target output. Defaults to PredictionType.DATA.


**Returns**:

- `Tensor` - The calculated loss batch tensor.

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatchervf_to_score"></a>

#### vf\_to\_score

```python
def vf_to_score(x_t: Tensor, v: Tensor, t: Tensor) -> Tensor
```

From Geffner et al. Computes the score of the noisy density given the vector field learned by flow matching.

With our interpolation scheme these are related by

v(x_t, t) = (1 / t) (x_t + scale_ref ** 2 * (1 - t) * s(x_t, t)),

or equivalently,

s(x_t, t) = (t * v(x_t, t) - x_t) / (scale_ref ** 2 * (1 - t)),

with scale_ref = 1.

**Arguments**:

- `x_t` - Noisy sample, shape [*, dim]
- `v` - Vector field, shape [*, dim]
- `t` - Interpolation time, shape [*] (must be < 1)


**Returns**:

  Score of intermediate density, shape [*, dim].
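
A direct transcription of the relation above with scale_ref = 1 (a sketch; t must be strictly less than 1):

```python
from torch import Tensor


def vf_to_score_sketch(x_t: Tensor, v: Tensor, t: Tensor) -> Tensor:
    # s(x_t, t) = (t * v(x_t, t) - x_t) / (1 - t), assuming t broadcasts from the batch dim.
    t = t.view(-1, *([1] * (x_t.dim() - 1)))
    return (t * v - x_t) / (1.0 - t)
```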

<a id="mocointerpolantscontinuous_timecontinuouscontinuous_flow_matchingContinuousFlowMatcherget_gt"></a>

#### get\_gt

```python
def get_gt(t: Tensor,
           mode: str = "tan",
           param: float = 1.0,
           clamp_val: Optional[float] = None,
           eps: float = 1e-2) -> Tensor
```

From Geffner et al. Computes gt for different modes.

**Arguments**:

- `t` - times where we'll evaluate, covers [0, 1), shape [nsteps]
- `mode` - "us" or "tan"
- `param` - parameter of the gt transformation
- `clamp_val` - value to clamp gt; no clamping if None
- `eps` - small value; leave as is

<a id="mocointerpolantscontinuous_time"></a>

# bionemo.moco.interpolants.continuous\_time

<a id="mocointerpolants"></a>

# bionemo.moco.interpolants

<a id="mocointerpolantsbatch_augmentation"></a>

# bionemo.moco.interpolants.batch\_augmentation

<a id="mocointerpolantsbatch_augmentationBatchDataAugmentation"></a>

## BatchDataAugmentation Objects

```python
class BatchDataAugmentation()
```

Facilitates the creation of batch augmentation objects based on specified optimal transport types.

**Arguments**:

- `device` _str_ - The device to use for computations (e.g., 'cpu', 'cuda').
- `num_threads` _int_ - The number of threads to utilize.

<a id="mocointerpolantsbatch_augmentationBatchDataAugmentation__init__"></a>

#### \_\_init\_\_

```python
def __init__(device, num_threads)
```

Initializes a BatchDataAugmentation instance.

**Arguments**:

- `device` _str_ - Device for computation.
- `num_threads` _int_ - Number of threads to use.

<a id="mocointerpolantsbatch_augmentationBatchDataAugmentationcreate"></a>

#### create

```python
def create(method_type: AugmentationType)
```

Creates a batch augmentation object of the specified type.

**Arguments**:

- `method_type` _AugmentationType_ - The type of optimal transport method.


**Returns**:

  The augmentation object if the type is supported, otherwise **None**.
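
A hypothetical usage sketch; the concrete `AugmentationType` members live in bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.augmentation_types, and the member name below is an assumption:

```python
from bionemo.moco.interpolants.batch_augmentation import BatchDataAugmentation

factory = BatchDataAugmentation(device="cpu", num_threads=1)
# `create` returns None for unsupported types, so guard before use:
# sampler = factory.create(AugmentationType.KABSCH)  # assumed member name
# if sampler is not None:
#     noise, data = sampler.apply_augmentation(noise, data)
```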

<a id="mocointerpolantsdiscrete_timediscreted3pm"></a>

# bionemo.moco.interpolants.discrete\_time.discrete.d3pm

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PM"></a>

## D3PM Objects

```python
class D3PM(Interpolant)
```

A Discrete Denoising Diffusion Probabilistic Model (D3PM) interpolant.

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PM__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: DiscretePriorDistribution,
             noise_schedule: DiscreteNoiseSchedule,
             device: str = "cpu",
             last_time_idx: int = 0,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes the D3PM interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps, used to sample time points for the diffusion process.
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable, used as the starting point for the diffusion process.
- `noise_schedule` _DiscreteNoiseSchedule_ - The schedule of noise, defining the amount of noise added at each time step.
- `device` _str, optional_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `last_time_idx` _int, optional_ - The last time index to consider in the interpolation process. Defaults to 0.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PMinterpolate"></a>

#### interpolate

```python
def interpolate(data: Tensor, t: Tensor)
```

Interpolate using discrete interpolation method.

This method implements Equation 2 from the D3PM paper (https://arxiv.org/pdf/2107.03006), which
calculates the interpolated discrete state `xt` at time `t` given the input data and noise
via q(xt|x0) = Cat(xt; p = x0*Qt_bar).

**Arguments**:

- `data` _Tensor_ - The input data to be interpolated.
- `t` _Tensor_ - The time step at which to interpolate.


**Returns**:

- `Tensor` - The interpolated discrete state `xt` at time `t`.
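
A sketch of the categorical sampling above; the shapes of `Qt_bar` (cumulative transition matrices gathered per sample) are an assumption for illustration:

```python
import torch
from torch import Tensor


def d3pm_interpolate_sketch(data: Tensor, Qt_bar: Tensor) -> Tensor:
    # data: (batch, nodes) integer ids; Qt_bar: (batch, K, K) cumulative transitions.
    x0_onehot = torch.nn.functional.one_hot(data, num_classes=Qt_bar.shape[-1]).float()
    probs = torch.einsum("bnk,bkj->bnj", x0_onehot, Qt_bar)  # rows of q(x_t | x_0)
    return torch.distributions.Categorical(probs=probs).sample()  # (batch, nodes)
```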

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PMforward_process"></a>

#### forward\_process

```python
def forward_process(data: Tensor, t: Tensor) -> Tensor
```

Apply the forward process to the data at time t.

**Arguments**:

- `data` _Tensor_ - target discrete ids
- `t` _Tensor_ - time


**Returns**:

- `Tensor` - x(t) after applying the forward process

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PMstep"></a>

#### step

```python
def step(model_out: Tensor,
         t: Tensor,
         xt: Tensor,
         mask: Optional[Tensor] = None,
         temperature: Float = 1.0,
         model_out_is_logits: bool = True)
```

Perform a single step in the discrete interpolant method, transitioning from the current discrete state `xt` at time `t` to the next state.

This step involves:

1. Computing the predicted q-posterior logits using the model output `model_out` and the current state `xt` at time `t`.
2. Sampling the next state from the predicted q-posterior distribution using the Gumbel-Softmax trick.

**Arguments**:

- `model_out` _Tensor_ - The output of the model at the current time step, which is used to compute the predicted q-posterior logits.
- `t` _Tensor_ - The current time step, which is used to index into the transition matrices and compute the predicted q-posterior logits.
- `xt` _Tensor_ - The current discrete state at time `t`, which is used to compute the predicted q-posterior logits and sample the next state.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the next state, which can be used to mask out certain tokens or regions. Defaults to None.
- `temperature` _Float, optional_ - The temperature to use for the Gumbel-Softmax trick, which controls the randomness of the sampling process. Defaults to 1.0.
- `model_out_is_logits` _bool, optional_ - A flag indicating whether the model output is already in logits form. If True, the output is assumed to be logits; otherwise, it is converted to logits. Defaults to True.


**Returns**:

- `Tensor` - The next discrete state at time `t-1`.

<a id="mocointerpolantsdiscrete_timediscreted3pmD3PMloss"></a>

#### loss

```python
def loss(logits: Tensor,
         target: Tensor,
         xt: Tensor,
         time: Tensor,
         mask: Optional[Tensor] = None,
         vb_scale: Float = 0.0)
```

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the target batch x node. If a mask is provided, the loss is
calculated only for the non-masked elements. Additionally, if vb_scale is greater than 0, the variational lower bound loss is
calculated and added to the total loss.

**Arguments**:

- `logits` _Tensor_ - The predicted output from the model, with shape batch x node x class.
- `target` _Tensor_ - The target output for the model prediction, with shape batch x node.
- `xt` _Tensor_ - The current data point.
- `time` _Tensor_ - The time at which the loss is calculated.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `vb_scale` _Float, optional_ - The scale factor for the variational lower bound loss. Defaults to 0.0.


**Returns**:

- `Tensor` - The calculated loss tensor. If aggregate is True, the loss and variational lower bound loss are aggregated and
  returned as a single tensor. Otherwise, the loss and variational lower bound loss are returned as separate tensors.

<a id="mocointerpolantsdiscrete_timediscrete"></a>

# bionemo.moco.interpolants.discrete\_time.discrete

<a id="mocointerpolantsdiscrete_timecontinuousddpm"></a>

# bionemo.moco.interpolants.discrete\_time.continuous.ddpm

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPM"></a>

## DDPM Objects

```python
class DDPM(Interpolant)
```

A Denoising Diffusion Probabilistic Model (DDPM) interpolant.

-------

**Examples**:

```python
>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
>>> from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
>>> from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule


ddpm = DDPM(
    time_distribution = UniformTimeDistribution(discrete_time = True,...),
    prior_distribution = GaussianPrior(...),
    noise_schedule = DiscreteCosineNoiseSchedule(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = ddpm.sample_time(batch_size)
    noise = ddpm.sample_prior(data.shape)
    xt = ddpm.interpolate(data, time, noise)

    x_pred = model(xt, time)
    loss = ddpm.loss(x_pred, data, time).mean()  # loss() returns a batch tensor
    loss.backward()

# Generation
x_pred = ddpm.sample_prior(data.shape)
for t in DiscreteLinearInferenceSchedule(...).generate_schedule():
    time = torch.full((batch_size,), t)
    x_hat = model(x_pred, time)
    x_pred = ddpm.step(x_hat, time, x_pred)
return x_pred

```

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPM__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             noise_schedule: DiscreteNoiseSchedule,
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
             device: Union[str, torch.device] = "cpu",
             last_time_idx: int = 0,
             rng_generator: Optional[torch.Generator] = None)
```

Initializes the DDPM interpolant.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps, used to sample time points for the diffusion process.
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable, used as the starting point for the diffusion process.
- `noise_schedule` _DiscreteNoiseSchedule_ - The schedule of noise, defining the amount of noise added at each time step.
- `prediction_type` _PredictionType_ - The type of prediction, either "data" or another type. Defaults to "data".
- `device` _str_ - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
- `last_time_idx` _int, optional_ - The last time index for discrete time. Set to 0 if discrete time is T-1, ..., 0 or 1 if T, ..., 1. Defaults to 0.
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMforward_data_schedule"></a>

#### forward\_data\_schedule

```python
@property
def forward_data_schedule() -> torch.Tensor
```

Returns the forward data schedule.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMforward_noise_schedule"></a>

#### forward\_noise\_schedule

```python
@property
def forward_noise_schedule() -> torch.Tensor
```

Returns the forward noise schedule.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMreverse_data_schedule"></a>

#### reverse\_data\_schedule

```python
@property
def reverse_data_schedule() -> torch.Tensor
```

Returns the reverse data schedule.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMreverse_noise_schedule"></a>

#### reverse\_noise\_schedule

```python
@property
def reverse_noise_schedule() -> torch.Tensor
```

Returns the reverse noise schedule.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMlog_var"></a>

#### log\_var

```python
@property
def log_var() -> torch.Tensor
```

Returns the log variance.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMalpha_bar"></a>

#### alpha\_bar

```python
@property
def alpha_bar() -> torch.Tensor
```

Returns the alpha bar values.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMalpha_bar_prev"></a>

#### alpha\_bar\_prev

```python
@property
def alpha_bar_prev() -> torch.Tensor
```

Returns the previous alpha bar values.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMinterpolate"></a>

#### interpolate

```python
def interpolate(data: Tensor, t: Tensor, noise: Tensor)
```

Get x(t) with given time t from noise and data.

**Arguments**:

- `data` _Tensor_ - target
- `t` _Tensor_ - time
- `noise` _Tensor_ - noise from prior()
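
A sketch of the standard DDPM forward marginal, assuming alpha_bar_t has already been gathered at each sample's time index (the class exposes the full schedule via the `alpha_bar` property):

```python
from torch import Tensor


def ddpm_interpolate_sketch(data: Tensor, noise: Tensor, alpha_bar_t: Tensor) -> Tensor:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
    ab = alpha_bar_t.view(-1, *([1] * (data.dim() - 1)))
    return ab.sqrt() * data + (1.0 - ab).sqrt() * noise
```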

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMforward_process"></a>

#### forward\_process

```python
def forward_process(data: Tensor, t: Tensor, noise: Optional[Tensor] = None)
```

Get x(t) with given time t from noise and data.

**Arguments**:

- `data` _Tensor_ - target
- `t` _Tensor_ - time
- `noise` _Tensor, optional_ - noise from prior(). Defaults to None.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMprocess_data_prediction"></a>

#### process\_data\_prediction

```python
def process_data_prediction(model_output: Tensor, sample: Tensor, t: Tensor)
```

Converts the model output to a data prediction based on the prediction type.

This conversion stems from Progressive Distillation for Fast Sampling of Diffusion Models, https://arxiv.org/pdf/2202.00512.
Given the model output and the sample, we convert the output to a data prediction based on the prediction type.
The conversion formulas are as follows:
- For "noise" prediction type: `pred_data = (sample - noise_scale * model_output) / data_scale`
- For "data" prediction type: `pred_data = model_output`
- For "v_prediction" prediction type: `pred_data = data_scale * sample - noise_scale * model_output`

**Arguments**:

- `model_output` _Tensor_ - The output of the model.
- `sample` _Tensor_ - The input sample.
- `t` _Tensor_ - The time step.


**Returns**:

  The data prediction based on the prediction type.


**Raises**:

- `ValueError` - If the prediction type is not one of "noise", "data", or "v_prediction".
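
The listed formulas transcribe directly; here `data_scale` and `noise_scale` are assumed to be sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t), broadcastable to `sample` (a sketch, not the class's exact code):

```python
from torch import Tensor


def to_data_prediction(model_output: Tensor, sample: Tensor, data_scale: Tensor,
                       noise_scale: Tensor, prediction_type: str) -> Tensor:
    if prediction_type == "noise":
        return (sample - noise_scale * model_output) / data_scale
    if prediction_type == "data":
        return model_output
    if prediction_type == "v_prediction":
        return data_scale * sample - noise_scale * model_output
    raise ValueError(f"Unknown prediction type: {prediction_type}")
```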

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMprocess_noise_prediction"></a>

#### process\_noise\_prediction

```python
def process_noise_prediction(model_output, sample, t)
```

Do the same as process_data_prediction, but convert the model output to noise.

**Arguments**:

- `model_output` - The output of the model.
- `sample` - The input sample.
- `t` - The time step.


**Returns**:

  The input as noise if the prediction type is "noise".


**Raises**:

- `ValueError` - If the prediction type is not "noise".

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMcalculate_velocity"></a>

#### calculate\_velocity

```python
def calculate_velocity(data: Tensor, t: Tensor, noise: Tensor) -> Tensor
```

Calculate the velocity term given the data, time step, and noise.

**Arguments**:

- `data` _Tensor_ - The input data.
- `t` _Tensor_ - The current time step.
- `noise` _Tensor_ - The noise term.


**Returns**:

- `Tensor` - The calculated velocity term.
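
Under the v-parameterization of https://arxiv.org/pdf/2202.00512, the velocity target can be sketched as v = alpha_t * eps - sigma_t * x_0, with alpha_t = sqrt(alpha_bar_t) and sigma_t = sqrt(1 - alpha_bar_t) (an assumption about the parameterization, not the class's exact code):

```python
from torch import Tensor


def velocity_sketch(data: Tensor, noise: Tensor, alpha_bar_t: Tensor) -> Tensor:
    # v = alpha_t * eps - sigma_t * x_0 under the standard DDPM marginal.
    ab = alpha_bar_t.view(-1, *([1] * (data.dim() - 1)))
    return ab.sqrt() * noise - (1.0 - ab).sqrt() * data
```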

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMstep"></a>

#### step

```python
@torch.no_grad()
def step(model_out: Tensor,
         t: Tensor,
         xt: Tensor,
         mask: Optional[Tensor] = None,
         center: Bool = False,
         temperature: Float = 1.0)
```

Do one step integration.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current data point.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low temperature sampling. Defaults to 1.0.


**Notes**:

  The temperature parameter controls the level of randomness in the sampling process. A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result in less random and more deterministic samples. This can be useful for tasks that require more control over the generation process.

  Note that for discrete time we sample from [T-1, ..., 1, 0] over T steps, so t = 0 is sampled, hence the mask.
  For continuous time we start from [1, 1 - dt, ..., dt] over T steps, where s = t - dt; at the final step t = dt, so s = 0.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMstep_noise"></a>

#### step\_noise

```python
def step_noise(model_out: Tensor,
               t: Tensor,
               xt: Tensor,
               mask: Optional[Tensor] = None,
               center: Bool = False,
               temperature: Float = 1.0)
```

Do one step integration.

**Arguments**:

- `model_out` _Tensor_ - The output of the model.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current data point.
- `mask` _Optional[Tensor], optional_ - An optional mask to apply to the data. Defaults to None.
- `center` _bool, optional_ - Whether to center the data. Defaults to False.
- `temperature` _Float, optional_ - The temperature parameter for low temperature sampling. Defaults to 1.0.


**Notes**:

  The temperature parameter controls the level of randomness in the sampling process.
  A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2)
  result in less random and more deterministic samples. This can be useful for tasks
  that require more control over the generation process.

  Note that for discrete time we sample from [T-1, ..., 1, 0] over T steps, so t = 0 is sampled, hence the mask.
  For continuous time we start from [1, 1 - dt, ..., dt] over T steps, where s = t - dt; at the final step t = dt, so s = 0.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMscore"></a>

#### score

```python
def score(x_hat: Tensor, xt: Tensor, t: Tensor)
```

Converts the data prediction to the estimated score function.

**Arguments**:

- `x_hat` _Tensor_ - The predicted data point.
- `xt` _Tensor_ - The current data point.
- `t` _Tensor_ - The time step.


**Returns**:

  The estimated score function.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMstep_ddim"></a>

#### step\_ddim

```python
def step_ddim(model_out: Tensor,
              t: Tensor,
              xt: Tensor,
              mask: Optional[Tensor] = None,
              eta: Float = 0.0,
              center: Bool = False)
```

Do one step of DDIM sampling.

**Arguments**:

- `model_out` _Tensor_ - output of the model
- `t` _Tensor_ - current time step
- `xt` _Tensor_ - current data point
- `mask` _Optional[Tensor], optional_ - mask for the data point. Defaults to None.
- `eta` _Float, optional_ - DDIM sampling parameter. Defaults to 0.0.
- `center` _Bool, optional_ - whether to center the data point. Defaults to False.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMset_loss_weight_fn"></a>

#### set\_loss\_weight\_fn

```python
def set_loss_weight_fn(fn)
```

Sets the loss_weight attribute of the instance to the given function.

**Arguments**:

- `fn` - The function to set as the loss_weight attribute. This function should take three arguments: raw_loss, t, and weight_type.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMloss_weight"></a>

#### loss\_weight

```python
def loss_weight(raw_loss: Tensor, t: Optional[Tensor],
                weight_type: str) -> Tensor
```

Calculates the weight for the loss based on the given weight type.

The data_to_noise loss weight is derived in Equation (9) of https://arxiv.org/pdf/2202.00512.

**Arguments**:

- `raw_loss` _Tensor_ - The raw loss calculated from the model prediction and target.
- `t` _Tensor_ - The time step.
- `weight_type` _str_ - The type of weight to use. Can be "ones", "data_to_noise", or "noise_to_data".


**Returns**:

- `Tensor` - The weight for the loss.


**Raises**:

- `ValueError` - If the weight type is not recognized.

<a id="mocointerpolantsdiscrete_timecontinuousddpmDDPMloss"></a>

#### loss

```python
def loss(model_pred: Tensor,
         target: Tensor,
         t: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         weight_type: Literal["ones", "data_to_noise",
                              "noise_to_data"] = "ones")
```

Calculate the loss given the model prediction, data sample, and time.

The default weight_type is "ones", meaning no change (multiplying by all ones).
data_to_noise scales the data MSE loss into the loss that is theoretically equivalent
to noise prediction; noise_to_data is provided for the converse, for completeness.

**Arguments**:

- `model_pred` _Tensor_ - The predicted output from the model.
- `target` _Tensor_ - The target output for the model prediction.
- `t` _Tensor_ - The time at which the loss is calculated.
- `mask` _Optional[Tensor], optional_ - The mask for the data point. Defaults to None.
- `weight_type` _Literal["ones", "data_to_noise", "noise_to_data"]_ - The type of weight to use for the loss. Defaults to "ones".


**Returns**:

- `Tensor` - The calculated loss batch tensor.

<a id="mocointerpolantsdiscrete_timecontinuous"></a>

# bionemo.moco.interpolants.discrete\_time.continuous

<a id="mocointerpolantsdiscrete_time"></a>

# bionemo.moco.interpolants.discrete\_time

<a id="mocointerpolantsdiscrete_timeutils"></a>

# bionemo.moco.interpolants.discrete\_time.utils

<a id="mocointerpolantsdiscrete_timeutilssafe_index"></a>

#### safe\_index

```python
def safe_index(tensor: Tensor, index: Tensor, device: Optional[torch.device])
```

Safely indexes a tensor using a given index and returns the result on a specified device.

Note: forcing could be implemented as `return tensor[index.to(tensor.device)].to(device)`, but this incurs a costly device migration.

**Arguments**:

- `tensor` _Tensor_ - The tensor to be indexed.
- `index` _Tensor_ - The index to use for indexing the tensor.
- `device` _torch.device_ - The device on which the result should be returned.


**Returns**:

- `Tensor` - The indexed tensor on the specified device.


**Raises**:

- `ValueError` - If `tensor` and `index` are not on the same device.
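
A sketch matching the documented contract (raise on device mismatch, then index and move the result):

```python
from typing import Optional

import torch
from torch import Tensor


def safe_index_sketch(tensor: Tensor, index: Tensor, device: Optional[torch.device]) -> Tensor:
    if tensor.device != index.device:
        raise ValueError(f"tensor ({tensor.device}) and index ({index.device}) must be on the same device")
    return tensor[index].to(device)
```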

<a id="mocointerpolantsbase_interpolant"></a>

# bionemo.moco.interpolants.base\_interpolant

<a id="mocointerpolantsbase_interpolantstring_to_enum"></a>

#### string\_to\_enum

```python
def string_to_enum(value: Union[str, AnyEnum],
                   enum_type: Type[AnyEnum]) -> AnyEnum
```

Converts a string to an enum value of the specified type. If the input is already an enum instance, it is returned as-is.

**Arguments**:

- `value` _Union[str, E]_ - The string to convert or an existing enum instance.
- `enum_type` _Type[E]_ - The enum type to convert to.


**Returns**:

- `E` - The corresponding enum value.


**Raises**:

- `ValueError` - If the string does not correspond to any enum member.
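
A minimal sketch of the documented behavior, assuming enum members are matched by value:

```python
from enum import Enum
from typing import Type, TypeVar, Union

AnyEnum = TypeVar("AnyEnum", bound=Enum)


def string_to_enum_sketch(value: Union[str, AnyEnum], enum_type: Type[AnyEnum]) -> AnyEnum:
    if isinstance(value, enum_type):
        return value  # already an enum instance
    try:
        return enum_type(value)
    except ValueError as err:
        raise ValueError(f"{value!r} is not a valid {enum_type.__name__}") from err
```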

<a id="mocointerpolantsbase_interpolantpad_like"></a>

#### pad\_like

```python
def pad_like(source: Tensor, target: Tensor) -> Tensor
```

Pads the dimensions of the source tensor to match the dimensions of the target tensor.

**Arguments**:

- `source` _Tensor_ - The tensor to be padded.
- `target` _Tensor_ - The tensor that the source tensor should match in dimensions.


**Returns**:

- `Tensor` - The padded source tensor.


**Raises**:

- `ValueError` - If the source tensor has more dimensions than the target tensor.


**Example**:

  >>> source = torch.tensor([1, 2, 3])  # shape: (3,)
  >>> target = torch.tensor([[1, 2], [4, 5], [7, 8]])  # shape: (3, 2)
  >>> padded_source = pad_like(source, target)  # shape: (3, 1)
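
Consistent with the example above, the padding can be sketched as appending singleton dimensions until the ranks match (an illustration, not the library's exact code):

```python
from torch import Tensor


def pad_like_sketch(source: Tensor, target: Tensor) -> Tensor:
    if source.dim() > target.dim():
        raise ValueError("source has more dimensions than target")
    # (3,) against (3, 2) -> (3, 1): append trailing singleton dims.
    return source.view(*source.shape, *([1] * (target.dim() - source.dim())))
```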

<a id="mocointerpolantsbase_interpolantPredictionType"></a>

## PredictionType Objects

```python
class PredictionType(Enum)
```

An enumeration representing the type of prediction a Denoising Diffusion Probabilistic Model (DDPM) can be used for.

DDPMs are versatile models that can be utilized for various prediction tasks, including:

- **Data**: Predicting the original data distribution from a noisy input.
- **Noise**: Predicting the noise that was added to the original data to obtain the input.
- **Velocity**: Predicting the velocity or rate of change of the data, particularly useful for modeling temporal dynamics.

These prediction types can be used to train neural networks for specific tasks, such as denoising, image synthesis, or time-series forecasting.

<a id="mocointerpolantsbase_interpolantInterpolant"></a>

## Interpolant Objects

```python
class Interpolant(ABC)
```

An abstract base class representing an Interpolant.

This class serves as a foundation for creating interpolants that can be used
in various applications, providing a basic structure and interface for
interpolation-related operations.

<a id="mocointerpolantsbase_interpolantInterpolant__init__"></a>

#### \_\_init\_\_

```python
def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             device: Union[str, torch.device] = "cpu",
             rng_generator: Optional[torch.Generator] = None)
```

Initializes the Interpolant class.

**Arguments**:

- `time_distribution` _TimeDistribution_ - The distribution of time steps.
- `prior_distribution` _PriorDistribution_ - The prior distribution of the variable.
- `device` _Union[str, torch.device], optional_ - The device on which to operate. Defaults to "cpu".
- `rng_generator` - An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

<a id="mocointerpolantsbase_interpolantInterpolantinterpolate"></a>

#### interpolate

```python
@abstractmethod
def interpolate(*args, **kwargs) -> Tensor
```

Get x(t) with given time t from noise and data.

Interpolate between x0 and x1 at the given time t.

<a id="mocointerpolantsbase_interpolantInterpolantstep"></a>

#### step

```python
@abstractmethod
def step(*args, **kwargs) -> Tensor
```

Do one step integration.

<a id="mocointerpolantsbase_interpolantInterpolantgeneral_step"></a>

#### general\_step

```python
def general_step(method_name: str, kwargs: dict)
```

Calls a step method of the class by its name, passing the provided keyword arguments.

**Arguments**:

- `method_name` _str_ - The name of the step method to call.
- `kwargs` _dict_ - Keyword arguments to pass to the step method.


**Returns**:

  The result of the step method call.


**Raises**:

- `ValueError` - If the provided method name does not start with 'step'.
- `Exception` - If the step method call fails. The error message includes a list of available step methods.


**Notes**:

  This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.

<a id="mocointerpolantsbase_interpolantInterpolantsample_prior"></a>

#### sample\_prior

```python
def sample_prior(*args, **kwargs) -> Tensor
```

Sample from prior distribution.

This method generates a sample from the prior distribution specified by the
`prior_distribution` attribute.

**Returns**:

- `Tensor` - The generated sample from the prior distribution.

<a id="mocointerpolantsbase_interpolantInterpolantsample_time"></a>

#### sample\_time

```python
def sample_time(*args, **kwargs) -> Tensor
```

Sample from time distribution.

<a id="mocointerpolantsbase_interpolantInterpolantto_device"></a>

#### to\_device

```python
def to_device(device: str)
```

Moves all internal tensors to the specified device and updates the `self.device` attribute.

**Arguments**:

- `device` _str_ - The device to move the tensors to (e.g. "cpu", "cuda:0").


**Notes**:

  This method is used to transfer the internal state of the interpolant to a different device.
  It updates the `self.device` attribute to reflect the new device and moves all internal tensors to the specified device.

<a id="mocointerpolantsbase_interpolantInterpolantclean_mask_center"></a>

#### clean\_mask\_center

```python
def clean_mask_center(data: Tensor,
                      mask: Optional[Tensor] = None,
                      center: Bool = False) -> Tensor
```

Returns a clean tensor that has been masked and/or centered based on the function arguments.

**Arguments**:

- `data` - The input data with shape (..., nodes, features).
- `mask` - An optional mask to apply to the data with shape (..., nodes). If provided, it is used to calculate the CoM. Defaults to None.
- `center` - A boolean indicating whether to center the data around the calculated CoM. Defaults to False.


**Returns**:

  The data with shape (..., nodes, features), either centered around the CoM if `center` is True or unchanged if `center` is False.
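
A sketch of the masked center-of-mass (CoM) centering described above; the reduction over the node dimension is an assumption consistent with the documented shapes:

```python
from typing import Optional

from torch import Tensor


def clean_mask_center_sketch(data: Tensor, mask: Optional[Tensor] = None, center: bool = False) -> Tensor:
    if mask is not None:
        data = data * mask.unsqueeze(-1)  # zero out padded nodes
    if not center:
        return data
    if mask is None:
        com = data.mean(dim=-2, keepdim=True)
    else:
        denom = mask.sum(dim=-1, keepdim=True).unsqueeze(-1).clamp(min=1)
        com = data.sum(dim=-2, keepdim=True) / denom
    centered = data - com
    return centered * mask.unsqueeze(-1) if mask is not None else centered
```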

<a id="mocotesting"></a>

# bionemo.moco.testing

<a id="mocotestingparallel_test_utils"></a>

# bionemo.moco.testing.parallel\_test\_utils

<a id="mocotestingparallel_test_utilsparallel_context"></a>

#### parallel\_context

```python
@contextmanager
def parallel_context(rank: int = 0, world_size: int = 1)
```

Context manager for torch distributed testing.

Sets up and cleans up the distributed environment, including the device mesh.

**Arguments**:

- `rank` _int_ - The rank of the process. Defaults to 0.
- `world_size` _int_ - The world size of the distributed environment. Defaults to 1.


**Yields**:

  None
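
A usage sketch for a single-rank test (the assertion assumes the context initializes torch.distributed):

```python
import torch.distributed as dist

from bionemo.moco.testing.parallel_test_utils import parallel_context


def test_single_rank_context():
    with parallel_context(rank=0, world_size=1):
        assert dist.is_initialized()  # distributed environment is live inside the context
```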

<a id="mocotestingparallel_test_utilsclean_up_distributed"></a>

#### clean\_up\_distributed

```python
def clean_up_distributed() -> None
```

Cleans up the distributed environment.

Destroys the process group and empties the CUDA cache.

**Arguments**:

  None


**Returns**:

  None