
Table of Contents

moco

bionemo.moco.distributions

bionemo.moco.distributions.prior.distribution

PriorDistribution Objects

class PriorDistribution(ABC)

An abstract base class representing a prior distribution.

sample

@abstractmethod
def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu") -> Tensor

Generates a specified number of samples from the prior distribution.

Arguments:

  • shape Tuple - The shape of the samples to generate.
  • mask Optional[Tensor], optional - A tensor indicating which samples should be masked. Defaults to None.
  • device str, optional - The device on which to generate the samples. Defaults to "cpu".

Returns:

  • Tensor - A tensor of samples.

DiscretePriorDistribution Objects

class DiscretePriorDistribution(PriorDistribution)

An abstract base class representing a discrete prior distribution.

__init__

def __init__(num_classes: int, prior_dist: Tensor)

Initializes a DiscretePriorDistribution instance.

Arguments:

  • num_classes int - The number of classes in the discrete distribution.
  • prior_dist Tensor - The prior distribution over the classes.

Returns:

None

get_num_classes

def get_num_classes() -> int

Getter for num_classes.

get_prior_dist

def get_prior_dist() -> Tensor

Getter for prior_dist.

bionemo.moco.distributions.prior.discrete.uniform

DiscreteUniformPrior Objects

class DiscreteUniformPrior(DiscretePriorDistribution)

A subclass representing a discrete uniform prior distribution.

__init__

def __init__(num_classes: int = 10) -> None

Initializes a discrete uniform prior distribution.

Arguments:

  • num_classes int - The number of classes in the discrete uniform distribution. Defaults to 10.

sample

def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor

Generates a specified number of samples.

Arguments:

  • shape Tuple - The shape of the samples to generate.
  • device str - cpu or gpu.
  • mask Optional[Tensor] - An optional mask to apply to the samples. Defaults to None.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

  • Tensor - A tensor of samples.
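
A minimal usage sketch based on the signature above; the shape and class count are illustrative:

```python
import torch
from bionemo.moco.distributions.prior.discrete.uniform import DiscreteUniformPrior

# Uniform prior over 4 classes; each entry of the sample is a class id.
prior = DiscreteUniformPrior(num_classes=4)
rng = torch.Generator().manual_seed(0)  # optional, for reproducible sampling
samples = prior.sample((8, 16), rng_generator=rng)  # shape (8, 16)
```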

bionemo.moco.distributions.prior.discrete.custom

DiscreteCustomPrior Objects

class DiscreteCustomPrior(DiscretePriorDistribution)

A subclass representing a discrete custom prior distribution.

This class allows for the creation of a prior distribution with a custom probability mass function defined by the prior_dist tensor. For example, with 4 classes of data, prior_dist = [0.3, 0.2, 0.4, 0.1] assigns those probabilities to the 4 classes.

__init__

def __init__(prior_dist: Tensor, num_classes: int = 10) -> None

Initializes a DiscreteCustomPrior distribution.

Arguments:

  • prior_dist - A tensor representing the probability mass function of the prior distribution.
  • num_classes - The number of classes in the prior distribution. Defaults to 10.

Notes:

The prior_dist tensor should have a sum close to 1.0, as it represents a probability mass function.

sample

def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor

Samples from the discrete custom prior distribution.

Arguments:

  • shape - A tuple specifying the shape of the samples to generate.
  • mask - An optional tensor mask to apply to the samples, broadcastable to the sample shape. Defaults to None.
  • device - The device on which to generate the samples, specified as a string or a :class:torch.device. Defaults to "cpu".
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

A tensor of samples drawn from the prior distribution.
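
A minimal sketch using the 4-class example from above; shapes and values are illustrative:

```python
import torch
from bionemo.moco.distributions.prior.discrete.custom import DiscreteCustomPrior

# Probability mass function over 4 classes; should sum to ~1.0.
pmf = torch.tensor([0.3, 0.2, 0.4, 0.1])
prior = DiscreteCustomPrior(prior_dist=pmf, num_classes=4)
samples = prior.sample((8, 16))  # class ids drawn according to pmf
```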

bionemo.moco.distributions.prior.discrete

bionemo.moco.distributions.prior.discrete.mask

DiscreteMaskedPrior Objects

class DiscreteMaskedPrior(DiscretePriorDistribution)

A subclass representing a Discrete Masked prior distribution.

__init__

def __init__(num_classes: int = 10,
             mask_dim: Optional[int] = None,
             inclusive: bool = True) -> None

Discrete Masked prior distribution.

There are three ways to define the problem that are hard to mesh together:

  1. [..., M, ....] inclusive, mask anywhere --> existing LLM tokenizers, where the mask token has a specific location that is not at the end.
  2. [......, M] inclusive, mask on the end --> mask_dim = None with inclusive set to True (the default) sticks the mask on the end.
  3. [.....] + [M] exclusive --> the number of classes represents the number of data classes, and one wishes to add a separate MASK dimension.

Note that the pad_sample function is provided to help add this extra external dimension.

Arguments:

  • num_classes int - The number of classes in the distribution. Defaults to 10.
  • mask_dim int - The index for the mask token. Defaults to num_classes - 1 if inclusive or num_classes if exclusive.
  • inclusive bool - Whether the mask is included in the specified number of classes.
    If True, the mask is considered as one of the classes.
    If False, the mask is considered as an additional class. Defaults to True.

sample

def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor

Generates a specified number of samples.

Arguments:

  • shape Tuple - The shape of the samples to generate.
  • device str - cpu or gpu.
  • mask Optional[Tensor] - An optional mask to apply to the samples. Defaults to None.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

  • Tensor - A tensor of samples.

is_masked

def is_masked(sample: Tensor) -> Tensor

Creates a mask for whether a state is masked.

Arguments:

  • sample Tensor - The sample to check.

Returns:

  • Tensor - A float tensor indicating whether the sample is masked.

pad_sample

def pad_sample(sample: Tensor) -> Tensor

Pads the input sample with zeros along the last dimension.

Arguments:

  • sample Tensor - The input sample to be padded.

Returns:

  • Tensor - The padded sample.
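
A hedged sketch of the inclusive and exclusive conventions described above; the padded shape is an assumption based on the pad_sample docstring:

```python
import torch
import torch.nn.functional as F
from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior

# Inclusive: the mask token is one of the 10 classes (mask_dim defaults to 9).
prior = DiscreteMaskedPrior(num_classes=10, inclusive=True)
xt = prior.sample((4, 12))
masked = prior.is_masked(xt)  # float tensor marking masked positions

# Exclusive: 10 data classes plus a separate MASK dimension (mask_dim == 10).
# pad_sample zero-pads the last dimension so one-hot data gains the extra slot.
prior_ex = DiscreteMaskedPrior(num_classes=10, inclusive=False)
one_hot = F.one_hot(torch.randint(0, 10, (4, 12)), num_classes=10).float()
padded = prior_ex.pad_sample(one_hot)  # assumed shape (4, 12, 11)
```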

bionemo.moco.distributions.prior.continuous.harmonic

LinearHarmonicPrior Objects

class LinearHarmonicPrior(PriorDistribution)

A subclass representing a Linear Harmonic prior distribution from Jing et al. https://arxiv.org/abs/2304.02198.

__init__

def __init__(length: Optional[int] = None,
             distance: Float = 3.8,
             center: Bool = False,
             rng_generator: Optional[torch.Generator] = None,
             device: Union[str, torch.device] = "cpu") -> None

Linear Harmonic prior distribution.

Arguments:

  • length Optional[int] - The number of points in a batch.
  • distance Float - RMS distance between adjacent points in the line graph.
  • center bool - Whether to center the samples around the mean. Defaults to False.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.
  • device Optional[str] - Device to place the schedule on (default is "cpu").

sample

def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor

Generates a specified number of samples from the Harmonic prior distribution.

Arguments:

  • shape Tuple - The shape of the samples to generate.
  • device str - cpu or gpu.
  • mask Optional[Tensor] - An optional mask to apply to the samples. Defaults to None.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

  • Tensor - A tensor of samples.

bionemo.moco.distributions.prior.continuous

bionemo.moco.distributions.prior.continuous.gaussian

GaussianPrior Objects

class GaussianPrior(PriorDistribution)

A subclass representing a Gaussian prior distribution.

__init__

def __init__(mean: Float = 0.0,
             std: Float = 1.0,
             center: Bool = False,
             rng_generator: Optional[torch.Generator] = None) -> None

Gaussian prior distribution.

Arguments:

  • mean Float - The mean of the Gaussian distribution. Defaults to 0.0.
  • std Float - The standard deviation of the Gaussian distribution. Defaults to 1.0.
  • center bool - Whether to center the samples around the mean. Defaults to False.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

sample

def sample(shape: Tuple,
           mask: Optional[Tensor] = None,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Tensor

Generates a specified number of samples from the Gaussian prior distribution.

Arguments:

  • shape Tuple - The shape of the samples to generate.
  • device str - cpu or gpu.
  • mask Optional[Tensor] - An optional mask to apply to the samples. Defaults to None.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

  • Tensor - A tensor of samples.
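
A minimal sketch; the (batch, nodes, features) shape is illustrative of, e.g., 3D coordinates:

```python
import torch
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior

prior = GaussianPrior(mean=0.0, std=1.0, center=True)  # center=True removes the mean
rng = torch.Generator().manual_seed(0)
x0 = prior.sample((8, 32, 3), rng_generator=rng)  # (batch, nodes, features)
```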

bionemo.moco.distributions.prior.continuous.utils

remove_center_of_mass

def remove_center_of_mass(data: Tensor,
                          mask: Optional[Tensor] = None) -> Tensor

Calculates the center of mass (CoM) of the given data.

Arguments:

  • data - The input data with shape (..., nodes, features).
  • mask - An optional binary mask to apply to the data with shape (..., nodes) to mask out interaction from CoM calculation. Defaults to None.

Returns:

The CoM of the data with shape (..., 1, features).
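
A minimal call sketch matching the shapes documented above:

```python
import torch
from bionemo.moco.distributions.prior.continuous.utils import remove_center_of_mass

data = torch.randn(2, 5, 3)   # (batch, nodes, features)
mask = torch.ones(2, 5)       # 1 where a node contributes to the CoM
out = remove_center_of_mass(data, mask)
```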

bionemo.moco.distributions.prior

bionemo.moco.distributions.time.distribution

TimeDistribution Objects

class TimeDistribution(ABC)

An abstract base class representing a time distribution.

Arguments:

  • discrete_time Bool - Whether the time is discrete.
  • nsteps Optional[int] - Number of steps for discretization.
  • min_t Optional[Float] - Min continuous time.
  • max_t Optional[Float] - Max continuous time.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

__init__

def __init__(discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             min_t: Optional[Float] = None,
             max_t: Optional[Float] = None,
             rng_generator: Optional[torch.Generator] = None)

Initializes a TimeDistribution object.

sample

@abstractmethod
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Float

Generates a specified number of samples from the time distribution.

Arguments:

  • n_samples int - The number of samples to generate.
  • device str - cpu or gpu.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

  • Float - A list or array of samples.

MixTimeDistribution Objects

class MixTimeDistribution()

An abstract base class representing a mixed time distribution.

uniform_dist = UniformTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False)
beta_dist = BetaTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False, p1=2.0, p2=1.0)
mix_dist = MixTimeDistribution(uniform_dist, beta_dist, mix_fraction=0.5)

__init__

def __init__(dist1: TimeDistribution, dist2: TimeDistribution,
             mix_fraction: Float)

Initializes a MixTimeDistribution object.

Arguments:

  • dist1 TimeDistribution - The first time distribution.
  • dist2 TimeDistribution - The second time distribution.
  • mix_fraction Float - The fraction of samples to draw from dist1. Must be between 0 and 1.

sample

def sample(n_samples: int,
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None) -> Float

Generates a specified number of samples from the mixed time distribution.

Arguments:

  • n_samples int - The number of samples to generate.
  • device str - cpu or gpu.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

  • Float - A list or array of samples.

bionemo.moco.distributions.time.uniform

UniformTimeDistribution Objects

class UniformTimeDistribution(TimeDistribution)

A class representing a uniform time distribution.

__init__

def __init__(min_t: Float = 0.0,
             max_t: Float = 1.0,
             discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             rng_generator: Optional[torch.Generator] = None)

Initializes a UniformTimeDistribution object.

Arguments:

  • min_t Float - The minimum time value.
  • max_t Float - The maximum time value.
  • discrete_time Bool - Whether the time is discrete.
  • nsteps Optional[int] - Number of steps for discretization.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

sample

def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None)

Generates a specified number of samples from the uniform time distribution.

Arguments:

  • n_samples int - The number of samples to generate.
  • device str - cpu or gpu.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

A tensor of samples.
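
A minimal sketch of both the continuous and discrete modes:

```python
import torch
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution

# Continuous times in [min_t, max_t].
t_dist = UniformTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False)
t = t_dist.sample(64)

# Discrete time indices require nsteps.
step_dist = UniformTimeDistribution(discrete_time=True, nsteps=1000)
steps = step_dist.sample(64)
```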

SymmetricUniformTimeDistribution Objects

class SymmetricUniformTimeDistribution(TimeDistribution)

A class representing a symmetric uniform time distribution.

__init__

def __init__(min_t: Float = 0.0,
             max_t: Float = 1.0,
             discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             rng_generator: Optional[torch.Generator] = None)

Initializes a SymmetricUniformTimeDistribution object.

Arguments:

  • min_t Float - The minimum time value.
  • max_t Float - The maximum time value.
  • discrete_time Bool - Whether the time is discrete.
  • nsteps Optional[int] - Number of steps for discretization.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

sample

def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None)

Generates a specified number of samples from the uniform time distribution.

Arguments:

  • n_samples int - The number of samples to generate.
  • device str - cpu or gpu.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

A tensor of samples.

bionemo.moco.distributions.time.logit_normal

LogitNormalTimeDistribution Objects

class LogitNormalTimeDistribution(TimeDistribution)

A class representing a logit normal time distribution.

__init__

def __init__(p1: Float = 0.0,
             p2: Float = 1.0,
             min_t: Float = 0.0,
             max_t: Float = 1.0,
             discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             rng_generator: Optional[torch.Generator] = None)

Initializes a LogitNormalTimeDistribution object.

Arguments:

  • p1 Float - The first shape parameter of the logit normal distribution i.e. the mean.
  • p2 Float - The second shape parameter of the logit normal distribution i.e. the std.
  • min_t Float - The minimum time value.
  • max_t Float - The maximum time value.
  • discrete_time Bool - Whether the time is discrete.
  • nsteps Optional[int] - Number of steps for discretization.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

sample

def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None)

Generates a specified number of samples from the logit normal time distribution.

Arguments:

  • n_samples int - The number of samples to generate.
  • device str - cpu or gpu.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

A tensor of samples.

bionemo.moco.distributions.time

bionemo.moco.distributions.time.beta

BetaTimeDistribution Objects

class BetaTimeDistribution(TimeDistribution)

A class representing a beta time distribution.

__init__

def __init__(p1: Float = 2.0,
             p2: Float = 1.0,
             min_t: Float = 0.0,
             max_t: Float = 1.0,
             discrete_time: Bool = False,
             nsteps: Optional[int] = None,
             rng_generator: Optional[torch.Generator] = None)

Initializes a BetaTimeDistribution object.

Arguments:

  • p1 Float - The first shape parameter of the beta distribution.
  • p2 Float - The second shape parameter of the beta distribution.
  • min_t Float - The minimum time value.
  • max_t Float - The maximum time value.
  • discrete_time Bool - Whether the time is discrete.
  • nsteps Optional[int] - Number of steps for discretization.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

sample

def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
           device: Union[str, torch.device] = "cpu",
           rng_generator: Optional[torch.Generator] = None)

Generates a specified number of samples from the beta time distribution.

Arguments:

  • n_samples int - The number of samples to generate.
  • device str - cpu or gpu.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

Returns:

A tensor of samples.

bionemo.moco.distributions.time.utils

float_time_to_index

def float_time_to_index(time: torch.Tensor,
                        num_time_steps: int) -> torch.Tensor

Convert a float time value to a time index.

Arguments:

  • time torch.Tensor - A tensor of float time values in the range [0, 1].
  • num_time_steps int - The number of discrete time steps.

Returns:

  • torch.Tensor - A tensor of time indices corresponding to the input float time values.
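
A small sketch; the exact rounding convention (floor vs. round) is an implementation detail of the function:

```python
import torch
from bionemo.moco.distributions.time.utils import float_time_to_index

t = torch.tensor([0.0, 0.25, 0.5, 1.0])
idx = float_time_to_index(t, num_time_steps=1000)  # integer indices in [0, 999]
```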

bionemo.moco.schedules.noise.continuous_snr_transforms

log

def log(t, eps=1e-20)

Compute the natural logarithm of a tensor, clamping values to avoid numerical instability.

Arguments:

  • t Tensor - The input tensor.
  • eps float, optional - The minimum value to clamp the input tensor (default is 1e-20).

Returns:

  • Tensor - The natural logarithm of the input tensor.

ContinuousSNRTransform Objects

class ContinuousSNRTransform(ABC)

A base class for continuous SNR schedules.

__init__

def __init__(direction: TimeDirection)

Initialize the ContinuousSNRTransform.

Arguments:

  • direction TimeDirection - Required; defines the direction in which the scheduler was built.

calculate_log_snr

def calculate_log_snr(t: Tensor,
                      device: Union[str, torch.device] = "cpu",
                      synchronize: Optional[TimeDirection] = None) -> Tensor

Public wrapper to generate the time schedule as a tensor.

Arguments:

  • t Tensor - The input tensor representing the time steps, with values ranging from 0 to 1.
  • device Optional[str] - The device to place the schedule on. Defaults to "cpu".
  • synchronize Optional[TimeDirection] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
    this parameter allows flipping the direction to match the specified one. Defaults to None.

Returns:

  • Tensor - A tensor representing the log signal-to-noise (SNR) ratio for the given time steps.

log_snr_to_alphas_sigmas

def log_snr_to_alphas_sigmas(log_snr: Tensor) -> Tuple[Tensor, Tensor]

Converts log signal-to-noise ratio (SNR) to alpha and sigma values.

Arguments:

  • log_snr Tensor - The input log SNR tensor.

Returns:

tuple[Tensor, Tensor]: A tuple containing the square root of alpha and sigma values.
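
A hedged sketch tying calculate_log_snr and log_snr_to_alphas_sigmas together; the variance-preserving identity alpha² + sigma² ≈ 1 is an assumption consistent with the sigmoid-based alpha described further below:

```python
import torch
from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform

transform = CosineSNRTransform()  # defaults: nu=1.0, s=0.008
t = torch.linspace(0.0, 1.0, 10)
log_snr = transform.calculate_log_snr(t)
alpha, sigma = transform.log_snr_to_alphas_sigmas(log_snr)
# Assumed variance-preserving relation: alpha**2 + sigma**2 ≈ 1.
```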

derivative

def derivative(t: Tensor, func: Callable) -> Tensor

Compute the derivative of a function; supports batched single-variable inputs.

Arguments:

  • t Tensor - time variable at which derivatives are taken
  • func Callable - function for derivative calculation

Returns:

  • Tensor - derivative that is detached from the computational graph

calculate_general_sde_terms

def calculate_general_sde_terms(t)

Compute the general SDE terms for a given time step t.

Arguments:

  • t Tensor - The input tensor representing the time step.

Returns:

tuple[Tensor, Tensor]: A tuple containing the drift term f_t and the diffusion term g_t_2.

Notes:

This method computes the drift and diffusion terms of the general SDE, which can be used to simulate the stochastic process.
The drift term represents the deterministic part of the process, while the diffusion term represents the stochastic part.

calculate_beta

def calculate_beta(t)

Compute the drift coefficient for the OU process of the form $dx = -\frac{1}{2} \beta(t)\, x\, dt + \sqrt{\beta(t)}\, dw_t$.

$\beta(t) = \frac{d}{dt} \log(\alpha^2) = \frac{2}{\alpha}\,\frac{d\alpha}{dt}$

Arguments:

  • t Union[float, Tensor] - t in [0, 1]

Returns:

  • Tensor - beta(t)

calculate_alpha_log_snr

def calculate_alpha_log_snr(log_snr: Tensor) -> Tensor

Compute alpha values based on the log SNR.

Arguments:

  • log_snr Tensor - The input tensor representing the log signal-to-noise ratio.

Returns:

  • Tensor - A tensor representing the alpha values for the given log SNR.

Notes:

This method computes alpha values as the square root of the sigmoid of the log SNR.

calculate_alpha_t

def calculate_alpha_t(t: Tensor) -> Tensor

Compute alpha values based on the log SNR schedule.

Arguments:

  • t Tensor - The input tensor representing the time steps.

Returns:

  • Tensor - A tensor representing the alpha values for the given time steps.

Notes:

This method computes alpha values as the square root of the sigmoid of the log SNR.

CosineSNRTransform Objects

class CosineSNRTransform(ContinuousSNRTransform)

A cosine SNR schedule.

Arguments:

  • nu Optional[Float] - Hyperparameter for the cosine schedule exponent (default is 1.0).
  • s Optional[Float] - Hyperparameter for the cosine schedule shift (default is 0.008).

__init__

def __init__(nu: Float = 1.0, s: Float = 0.008)

Initialize the CosineSNRTransform.

LinearSNRTransform Objects

class LinearSNRTransform(ContinuousSNRTransform)

A Linear SNR schedule.

__init__

def __init__(min_value: Float = 1.0e-4)

Initialize the Linear SNR Transform.

Arguments:

  • min_value Float - Minimum value of the SNR. Defaults to 1e-4.

LinearLogInterpolatedSNRTransform Objects

class LinearLogInterpolatedSNRTransform(ContinuousSNRTransform)

A Linear Log space interpolated SNR schedule.

__init__

def __init__(min_value: Float = -7.0, max_value=13.5)

Initialize the Linear log space interpolated SNR Schedule from Chroma.

Arguments:

  • min_value Float - The min log SNR value.
  • max_value Float - The max log SNR value.

bionemo.moco.schedules.noise.discrete_noise_schedules

DiscreteNoiseSchedule Objects

class DiscreteNoiseSchedule(ABC)

A base class for discrete noise schedules.

__init__

def __init__(nsteps: int, direction: TimeDirection)

Initialize the DiscreteNoiseSchedule.

Arguments:

  • nsteps int - number of discrete steps.
  • direction TimeDirection - Required; defines the direction in which the scheduler was built.

generate_schedule

def generate_schedule(nsteps: Optional[int] = None,
                      device: Union[str, torch.device] = "cpu",
                      synchronize: Optional[TimeDirection] = None) -> Tensor

Generate the noise schedule as a tensor.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None, uses the value from initialization.
  • device Optional[str] - Device to place the schedule on (default is "cpu").
  • synchronize Optional[TimeDirection] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
    this parameter allows flipping the direction to match the specified one (default is None).

calculate_derivative

def calculate_derivative(
        nsteps: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        synchronize: Optional[TimeDirection] = None) -> Tensor

Calculate the time derivative of the schedule.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None, uses the value from initialization.
  • device Optional[str] - Device to place the schedule on (default is "cpu").
  • synchronize Optional[TimeDirection] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
    this parameter allows flipping the direction to match the specified one (default is None).

Returns:

  • Tensor - A tensor representing the time derivative of the schedule.

Raises:

  • NotImplementedError - If the derivative calculation is not implemented for this schedule.

DiscreteCosineNoiseSchedule Objects

class DiscreteCosineNoiseSchedule(DiscreteNoiseSchedule)

A cosine discrete noise schedule.

__init__

def __init__(nsteps: int, nu: Float = 1.0, s: Float = 0.008)

Initialize the DiscreteCosineNoiseSchedule.

Arguments:

  • nsteps int - Number of discrete steps.
  • nu Optional[Float] - Hyperparameter for the cosine schedule exponent (default is 1.0).
  • s Optional[Float] - Hyperparameter for the cosine schedule shift (default is 0.008).

DiscreteLinearNoiseSchedule Objects

class DiscreteLinearNoiseSchedule(DiscreteNoiseSchedule)

A linear discrete noise schedule.

__init__

def __init__(nsteps: int, beta_start: Float = 1e-4, beta_end: Float = 0.02)

Initialize the DiscreteLinearNoiseSchedule.

Arguments:

  • nsteps int - Number of discrete steps.
  • beta_start Float - Starting beta value. Defaults to 1e-4.
  • beta_end Float - End beta value. Defaults to 0.02.
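
A minimal sketch of generating and direction-synchronizing a discrete schedule via the base-class API:

```python
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
from bionemo.moco.schedules.utils import TimeDirection

schedule = DiscreteCosineNoiseSchedule(nsteps=1000)
values = schedule.generate_schedule()  # uses nsteps from initialization
# Flip to match a UNIFIED (noise -> data) convention if the schedule was built the other way.
values_unified = schedule.generate_schedule(synchronize=TimeDirection.UNIFIED)
```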

bionemo.moco.schedules.noise

bionemo.moco.schedules.noise.continuous_noise_transforms

ContinuousExpNoiseTransform Objects

class ContinuousExpNoiseTransform(ABC)

A base class for continuous schedules.

alpha = exp(- sigma) where 1 - alpha controls the masking fraction.

__init__

def __init__(direction: TimeDirection)

Initialize the ContinuousExpNoiseTransform.

Arguments:

  • direction TimeDirection - Required; defines the direction in which the scheduler was built.

calculate_sigma

def calculate_sigma(t: Tensor,
                    device: Union[str, torch.device] = "cpu",
                    synchronize: Optional[TimeDirection] = None) -> Tensor

Calculate the sigma for the given time steps.

Arguments:

  • t Tensor - The input tensor representing the time steps, with values ranging from 0 to 1.
  • device Optional[str] - The device to place the schedule on. Defaults to "cpu".
  • synchronize Optional[TimeDirection] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
    this parameter allows flipping the direction to match the specified one. Defaults to None.

Returns:

  • Tensor - A tensor representing the sigma values for the given time steps.

Raises:

  • ValueError - If the input time steps exceed the maximum allowed value of 1.

sigma_to_alpha

def sigma_to_alpha(sigma: Tensor) -> Tensor

Converts sigma to alpha values by alpha = exp(- sigma).

Arguments:

  • sigma Tensor - The input sigma tensor.

Returns:

  • Tensor - A tensor containing the alpha values.

CosineExpNoiseTransform Objects

class CosineExpNoiseTransform(ContinuousExpNoiseTransform)

A cosine Exponential noise schedule.

__init__

def __init__(eps: Float = 1.0e-3)

Initialize the CosineExpNoiseTransform.

Arguments:

  • eps Float - small number to prevent numerical issues.

d_dt_sigma

def d_dt_sigma(t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor

Compute the derivative of sigma with respect to time.

Arguments:

  • t Tensor - The input tensor representing the time steps.
  • device Optional[str] - The device to place the schedule on. Defaults to "cpu".

Returns:

  • Tensor - A tensor representing the derivative of sigma with respect to time.

Notes:

The derivative of sigma as a function of time is given by:

$\frac{d}{dt}\sigma(t) = \frac{d}{dt}\left(-\log\left(\cos\left(\tfrac{\pi t}{2}\right) + \epsilon\right)\right)$

Using the chain rule, we get:

$\frac{d}{dt}\sigma(t) = \left(\frac{-1}{\cos(\pi t / 2) + \epsilon}\right)\left(-\sin\left(\tfrac{\pi t}{2}\right) \cdot \tfrac{\pi}{2}\right) = \frac{\pi}{2}\,\frac{\sin(\pi t / 2)}{\cos(\pi t / 2) + \epsilon}$

This is the derivative that is computed and returned by this method.
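
A minimal sketch relating sigma, alpha, and the derivative above:

```python
import torch
from bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform

transform = CosineExpNoiseTransform(eps=1.0e-3)
t = torch.linspace(0.0, 0.99, 5)
sigma = transform.calculate_sigma(t)
alpha = transform.sigma_to_alpha(sigma)  # alpha = exp(-sigma); 1 - alpha is the masking fraction
dsigma_dt = transform.d_dt_sigma(t)
```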

LogLinearExpNoiseTransform Objects

class LogLinearExpNoiseTransform(ContinuousExpNoiseTransform)

A log linear exponential schedule.

__init__

def __init__(eps: Float = 1.0e-3)

Initialize the LogLinearExpNoiseTransform.

Arguments:

  • eps Float - small value to prevent numerical issues.

d_dt_sigma

def d_dt_sigma(t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor

Compute the derivative of sigma with respect to time.

Arguments:

  • t Tensor - The input tensor representing the time steps.
  • device Optional[str] - The device to place the schedule on. Defaults to "cpu".

Returns:

  • Tensor - A tensor representing the derivative of sigma with respect to time.

bionemo.moco.schedules

bionemo.moco.schedules.utils

TimeDirection Objects

class TimeDirection(Enum)

Enum for the direction of the noise schedule.

UNIFIED

Noise(0) --> Data(1)

DIFFUSION

Noise(1) --> Data(0)

bionemo.moco.schedules.inference_time_schedules

InferenceSchedule Objects

class InferenceSchedule(ABC)

A base class for inference time schedules.

__init__

def __init__(nsteps: int,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")

Initialize the InferenceSchedule.

Arguments:

  • nsteps int - Number of time steps.
  • min_t Float - Minimum time value. Defaults to 0.
  • padding Float - Padding time value. Defaults to 0.
  • dilation Float - Dilation time value, i.e. the number of replicates. Defaults to 0.
  • direction Union[TimeDirection, str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows flipping the direction to match the specified one (default is TimeDirection.UNIFIED).
  • device Optional[str] - Device to place the schedule on (default is "cpu").

generate_schedule

@abstractmethod
def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor

Generate the time schedule as a tensor.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None, uses the value from initialization.
  • device Optional[str] - Device to place the schedule on (default is "cpu").

pad_time

def pad_time(n_samples: int,
             scalar_time: Float,
             device: Optional[Union[str, torch.device]] = None) -> Tensor

Creates a tensor of shape (n_samples,) filled with a scalar time value.

Arguments:

  • n_samples int - The desired dimension of the output tensor.
  • scalar_time Float - The scalar time value to fill the tensor with.
  • device Optional[Union[str, torch.device]], optional - The device to place the tensor on. Defaults to None, which uses the default device.

Returns:

  • Tensor - A tensor of shape (n_samples,) filled with the scalar time value.

ContinuousInferenceSchedule Objects

class ContinuousInferenceSchedule(InferenceSchedule)

A base class for continuous time inference schedules.

__init__

def __init__(nsteps: int,
             inclusive_end: bool = False,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")

Initialize the ContinuousInferenceSchedule.

Arguments:

  • nsteps int - Number of time steps.
  • inclusive_end bool - If True, include the end value (1.0) in the schedule; otherwise the schedule ends at 1.0 - 1/nsteps (default is False).
  • min_t Float - Minimum time value. Defaults to 0.
  • padding Float - Padding time value. Defaults to 0.
  • dilation Float - Dilation time value, i.e. the number of replicates. Defaults to 0.
  • direction Union[TimeDirection, str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows flipping the direction to match the specified one (default is TimeDirection.UNIFIED).
  • device Optional[str] - Device to place the schedule on (default is "cpu").

discretize

def discretize(nsteps: Optional[int] = None,
               schedule: Optional[Tensor] = None,
               device: Optional[Union[str, torch.device]] = None) -> Tensor

Discretize the time schedule into a list of time deltas.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None, uses the value from initialization.
  • schedule Optional[Tensor] - Time schedule; if None, it is generated with generate_schedule.
  • device Optional[str] - Device to place the schedule on (default is "cpu").

Returns:

  • Tensor - A tensor of time deltas.

DiscreteInferenceSchedule Objects

class DiscreteInferenceSchedule(InferenceSchedule)

A base class for discrete time inference schedules.

discretize

def discretize(nsteps: Optional[int] = None,
               device: Optional[Union[str, torch.device]] = None) -> Tensor

Discretize the time schedule into a list of time deltas.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None, uses the value from initialization.
  • device Optional[str] - Device to place the schedule on (default is "cpu").

Returns:

  • Tensor - A tensor of time deltas.

DiscreteLinearInferenceSchedule Objects

class DiscreteLinearInferenceSchedule(DiscreteInferenceSchedule)

A linear time schedule for discrete time inference.

__init__

def __init__(nsteps: int,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")

Initialize the DiscreteLinearInferenceSchedule.

Arguments:

  • nsteps int - Number of time steps.
  • min_t Float - Minimum time value. Defaults to 0.
  • padding Float - Padding time value. Defaults to 0.
  • dilation Float - Dilation time value, i.e. the number of replicates. Defaults to 0.
  • direction Union[TimeDirection, str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows flipping the direction to match the specified one (default is TimeDirection.UNIFIED).
  • device Optional[str] - Device to place the schedule on (default is "cpu").

generate_schedule

def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor

Generate the linear time schedule as a tensor.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None uses the value from initialization.
  • device Optional[str] - Device to place the schedule on (default is "cpu").

Returns:

  • Tensor - A tensor of time steps.

LinearInferenceSchedule Objects

class LinearInferenceSchedule(ContinuousInferenceSchedule)

A linear time schedule for continuous time inference.

__init__

def __init__(nsteps: int,
             inclusive_end: bool = False,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")

Initialize the LinearInferenceSchedule.

Arguments:

  • nsteps int - Number of time steps.
  • inclusive_end bool - If True, include the end value (1.0) in the schedule; otherwise the schedule ends at 1.0 - 1/nsteps (default is False).
  • min_t Float - Minimum time value. Defaults to 0.
  • padding Float - Padding time value. Defaults to 0.
  • dilation Float - Dilation time value, i.e. the number of replicates. Defaults to 0.
  • direction Union[TimeDirection, str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows flipping the direction to match the specified one (default is TimeDirection.UNIFIED).
  • device Optional[str] - Device to place the schedule on (default is "cpu").

generate_schedule

def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor

Generate the linear time schedule as a tensor.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None uses the value from initialization.
  • device Optional[str] - Device to place the schedule on (default is "cpu").

Returns:

  • Tensor - A tensor of time steps.
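
A minimal sketch pairing generate_schedule with discretize, as used in the sampling loops later in this document:

```python
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

schedule = LinearInferenceSchedule(nsteps=100)
times = schedule.generate_schedule()  # 100 times starting at min_t, end exclusive by default
dts = schedule.discretize()           # per-step time deltas for the update rule
```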

PowerInferenceSchedule Objects

class PowerInferenceSchedule(ContinuousInferenceSchedule)

A power time schedule for inference, where time steps are generated by raising a uniform schedule to a specified power.

__init__

def __init__(nsteps: int,
             inclusive_end: bool = False,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             exponent: Float = 1.0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")

Initialize the PowerInferenceSchedule.

Arguments:

  • nsteps int - Number of time steps.
  • inclusive_end bool - If True, include the end value (1.0) in the schedule; otherwise the schedule ends below 1.0 (default is False).
  • min_t Float - Minimum time value. Defaults to 0.
  • padding Float - Padding time value. Defaults to 0.
  • dilation Float - Dilation time value, i.e. the number of replicates. Defaults to 0.
  • exponent Float - Power parameter. Defaults to 1.0.
  • direction Union[TimeDirection, str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows flipping the direction to match the specified one (default is TimeDirection.UNIFIED).
  • device Optional[str] - Device to place the schedule on (default is "cpu").

generate_schedule

def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor

Generate the power time schedule as a tensor.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None uses the value from initialization.
  • device Optional[str] - Device to place the schedule on (default is "cpu").

Returns:

  • Tensor - A tensor of time steps.
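
A minimal sketch; the effect of exponent on step spacing is an assumption from the power-law construction:

```python
from bionemo.moco.schedules.inference_time_schedules import PowerInferenceSchedule

# exponent=1.0 recovers a linear schedule; other values warp the spacing
# of steps across [0, 1] (assumed from raising a uniform schedule to a power).
schedule = PowerInferenceSchedule(nsteps=100, exponent=2.0)
times = schedule.generate_schedule()
```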

LogInferenceSchedule Objects

class LogInferenceSchedule(ContinuousInferenceSchedule)

A log time schedule for inference, where time steps are generated by taking the logarithm of a uniform schedule.

__init__

def __init__(nsteps: int,
             inclusive_end: bool = False,
             min_t: Float = 0,
             padding: Float = 0,
             dilation: Float = 0,
             exponent: Float = -2.0,
             direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
             device: Union[str, torch.device] = "cpu")

Initialize the LogInferenceSchedule.

Returns a log-spaced time schedule, which for 100 steps with default parameters is:
tensor([0.0000, 0.0455, 0.0889, 0.1303, 0.1699, 0.2077, 0.2439, 0.2783, 0.3113,
0.3427, 0.3728, 0.4015, 0.4288, 0.4550, 0.4800, 0.5039, 0.5266, 0.5484,
0.5692, 0.5890, 0.6080, 0.6261, 0.6434, 0.6599, 0.6756, 0.6907, 0.7051,
0.7188, 0.7319, 0.7444, 0.7564, 0.7678, 0.7787, 0.7891, 0.7991, 0.8086,
0.8176, 0.8263, 0.8346, 0.8425, 0.8500, 0.8572, 0.8641, 0.8707, 0.8769,
0.8829, 0.8887, 0.8941, 0.8993, 0.9043, 0.9091, 0.9136, 0.9180, 0.9221,
0.9261, 0.9299, 0.9335, 0.9369, 0.9402, 0.9434, 0.9464, 0.9492, 0.9520,
0.9546, 0.9571, 0.9595, 0.9618, 0.9639, 0.9660, 0.9680, 0.9699, 0.9717,
0.9734, 0.9751, 0.9767, 0.9782, 0.9796, 0.9810, 0.9823, 0.9835, 0.9847,
0.9859, 0.9870, 0.9880, 0.9890, 0.9899, 0.9909, 0.9917, 0.9925, 0.9933,
0.9941, 0.9948, 0.9955, 0.9962, 0.9968, 0.9974, 0.9980, 0.9985, 0.9990,
0.9995])

Arguments:

  • nsteps int - Number of time steps.
  • inclusive_end bool - If True, include the end value (1.0) in the schedule; otherwise the schedule ends below 1.0 (default is False).
  • min_t Float - Minimum time value. Defaults to 0.
  • padding Float - Padding time value. Defaults to 0.
  • dilation Float - Dilation time value, i.e. the number of replicates. Defaults to 0.
  • exponent Float - Log space exponent parameter. Defaults to -2.0. The lower the value, the more aggressive the acceleration from 0 to 0.9, leaving more steps between 0.9 and 1.0.
  • direction Union[TimeDirection, str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows flipping the direction to match the specified one (default is TimeDirection.UNIFIED).
  • device Optional[str] - Device to place the schedule on (default is "cpu").

generate_schedule

def generate_schedule(
        nsteps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None) -> Tensor

Generate the log time schedule as a tensor.

Arguments:

  • nsteps Optional[int] - Number of time steps. If None uses the value from initialization.
  • device Optional[str] - Device to place the schedule on (default is "cpu").

bionemo.moco.interpolants.continuous_time.discrete

bionemo.moco.interpolants.continuous_time.discrete.mdlm

MDLM Objects

class MDLM(Interpolant)

A Masked Discrete Diffusion Language Model (MDLM) interpolant.


Examples:

>>> import torch
>>> from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.discrete.mdlm import MDLM
>>> from bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule


mdlm = MDLM(
    time_distribution=UniformTimeDistribution(discrete_time=False, ...),
    prior_distribution=DiscreteMaskedPrior(...),
    noise_schedule=CosineExpNoiseTransform(...),
)
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = mdlm.sample_time(batch_size)
    xt = mdlm.interpolate(data, time)

    logits = model(xt, time)
    loss = mdlm.loss(logits, data, xt, time)
    loss.backward()

# Generation
x_pred = mdlm.sample_prior(data.shape)
schedule = LinearInferenceSchedule(...)
inference_time = schedule.generate_schedule()
dts = schedule.discretize()
for t, dt in zip(inference_time, dts):
    time = torch.full((batch_size,), t)
    logits = model(x_pred, time)
    x_pred = mdlm.step(logits, time, x_pred, dt)
return x_pred

__init__

def __init__(time_distribution: TimeDistribution,
             prior_distribution: DiscreteMaskedPrior,
             noise_schedule: ContinuousExpNoiseTransform,
             device: str = "cpu",
             rng_generator: Optional[torch.Generator] = None)

Initialize the Masked Discrete Language Model (MDLM) interpolant.

Arguments:

  • time_distribution TimeDistribution - The distribution governing the time variable in the diffusion process.
  • prior_distribution DiscreteMaskedPrior - The prior distribution over the discrete token space, including masked tokens.
  • noise_schedule ContinuousExpNoiseTransform - The noise schedule defining the noise intensity as a function of time.
  • device str, optional - The device to use for computations. Defaults to "cpu".
  • rng_generator Optional[torch.Generator], optional - The random number generator for reproducibility. Defaults to None.

interpolate

def interpolate(data: Tensor, t: Tensor)

Get x(t) with given time t from noise and data.

Arguments:

  • data Tensor - target discrete ids
  • t Tensor - time

forward_process

def forward_process(data: Tensor, t: Tensor) -> Tensor

Apply the forward process to the data at time t.

Arguments:

  • data Tensor - target discrete ids
  • t Tensor - time

Returns:

  • Tensor - x(t) after applying the forward process

loss

def loss(logits: Tensor,
         target: Tensor,
         xt: Tensor,
         time: Tensor,
         mask: Optional[Tensor] = None,
         use_weight=True)

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the target batch x node,
considering the current state of the discrete sequence xt at the given time.

If use_weight is True, the loss is weighted by the reduced form of the MDLM time weight for continuous NELBO,
as specified in equation 11 of https://arxiv.org/pdf/2406.07524. This weight is proportional to the derivative
of the noise schedule with respect to time, and is used to emphasize the importance of accurate predictions at
certain times in the diffusion process.

Arguments:

  • logits Tensor - The predicted output from the model, with shape batch x node x class.
  • target Tensor - The target output for the model prediction, with shape batch x node.
  • xt Tensor - The current state of the discrete sequence, with shape batch x node.
  • time Tensor - The time at which the loss is calculated.
  • mask Optional[Tensor], optional - The mask for the data point. Defaults to None.
  • use_weight bool, optional - Whether to use the MDLM time weight for the loss. Defaults to True.

Returns:

  • Tensor - The calculated loss batch tensor.

step

def step(logits: Tensor,
         t: Tensor,
         xt: Tensor,
         dt: Tensor,
         temperature: float = 1.0) -> Tensor

Perform a single MDLM DDPM step.

Arguments:

  • logits Tensor - The input logits.
  • t Tensor - The current time step.
  • xt Tensor - The current state.
  • dt Tensor - The time step increment.
  • temperature float - Softmax temperature defaults to 1.0.

Returns:

  • Tensor - The updated state.

get_num_steps_confidence

def get_num_steps_confidence(xt: Tensor, num_tokens_unmask: int = 1)

Calculate the maximum number of steps with confidence.

This method computes the maximum count of occurrences where the input tensor xt matches the mask_index
along the last dimension (-1). The result is returned as a single float value.

Arguments:

  • xt Tensor - Input tensor to evaluate against the mask index.
  • num_tokens_unmask int - Number of tokens to unmask at each step.

Returns:

  • float - The maximum number of steps with confidence (i.e., matching the mask index).

step_confidence

def step_confidence(logits: Tensor,
                    xt: Tensor,
                    curr_step: int,
                    num_steps: int,
                    logit_temperature: float = 1.0,
                    randomness: float = 1.0,
                    confidence_temperature: float = 1.0,
                    num_tokens_unmask: int = 1) -> Tensor

Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.

Method taken from GenMol Lee et al. https://arxiv.org/abs/2501.06158

Arguments:

  • logits - Predicted logits
  • xt - Input sequence
  • curr_step - Current step
  • num_steps - Total number of steps
  • logit_temperature - Temperature for softmax over logits
  • randomness - Scale for Gumbel noise
  • confidence_temperature - Temperature for Gumbel confidence
  • num_tokens_unmask - number of tokens to unmask each step

Returns:

Updated input sequence xt, unmasking num_tokens_unmask tokens at each step.

step_argmax

def step_argmax(model_out: Tensor)

Returns the index of the maximum value in the last dimension of the model output.

Arguments:

  • model_out Tensor - The output of the model.

Returns:

  • Tensor - The index of the maximum value in the last dimension of the model output.

calculate_score

def calculate_score(logits, x, t)

Returns score of the given sample x at time t with the corresponding model output logits.

Arguments:

  • logits Tensor - The output of the model.
  • x Tensor - The current data point.
  • t Tensor - The current time.

Returns:

  • Tensor - The score defined in Appendix C.3 Equation 76 of MDLM.

step_self_path_planning

def step_self_path_planning(logits: Tensor,
                            xt: Tensor,
                            t: Tensor,
                            curr_step: int,
                            num_steps: int,
                            logit_temperature: float = 1.0,
                            randomness: float = 1.0,
                            confidence_temperature: float = 1.0,
                            score_type: Literal["confidence",
                                                "random"] = "confidence",
                            fix_mask: Optional[Tensor] = None) -> Tensor

Self Path Planning (P2) Sampling from Peng et al. https://arxiv.org/html/2502.03540v1.

Arguments:

  • logits Tensor - Predicted logits for sampling.
  • xt Tensor - Input sequence to be updated.
  • t Tensor - Time tensor (e.g., time steps or temporal info).
  • curr_step int - Current iteration in the planning process.
  • num_steps int - Total number of planning steps.
  • logit_temperature float - Temperature for logits (default: 1.0).
  • randomness float - Introduced randomness level (default: 1.0).
  • confidence_temperature float - Temperature for confidence scoring (default: 1.0).
  • score_type Literal["confidence", "random"] - Sampling score type (default: "confidence").
  • fix_mask Optional[Tensor] - Initial mask that is True where the token is not a mask token (default: None).

Returns:

  • Tensor - Updated input sequence xt after iterative unmasking.

topk_lowest_masking

def topk_lowest_masking(scores: Tensor, cutoff_len: Tensor)

Generates a mask for the lowest scoring elements up to a specified cutoff length.

Arguments:

  • scores Tensor - Input scores tensor with shape (... , num_elements)
  • cutoff_len Tensor - Number of lowest-scoring elements to mask (per batch element)

Returns:

  • Tensor - Boolean mask tensor with same shape as scores, where True indicates
    the corresponding element is among the cutoff_len lowest scores.

Example:

scores = torch.tensor([[0.9, 0.8, 0.1, 0.05], [0.7, 0.4, 0.3, 0.2]])
cutoff_len = 2
mask = topk_lowest_masking(scores, cutoff_len)
print(mask)
tensor([[False, False, True, True],
[False, True, True, False]])

stochastic_sample_from_categorical

def stochastic_sample_from_categorical(logits: Tensor,
                                       temperature: float = 1.0,
                                       noise_scale: float = 1.0)

Stochastically samples from a categorical distribution defined by input logits, with optional temperature and noise scaling for diverse sampling.

Arguments:

  • logits Tensor - Input logits tensor with shape (... , num_categories)
  • temperature float, optional - Softmax temperature. Higher values produce more uniform samples. Defaults to 1.0.
  • noise_scale float, optional - Scale for Gumbel noise. Higher values produce more diverse samples. Defaults to 1.0.

Returns:

tuple:
- tokens (LongTensor): Sampling result (category indices) with shape (... , )
- scores (Tensor): Corresponding log-softmax scores for the sampled tokens, with shape (... , )

bionemo.moco.interpolants.continuous_time.discrete.discrete_flow_matching

DiscreteFlowMatcher Objects

class DiscreteFlowMatcher(Interpolant)

A Discrete Flow Model (DFM) interpolant.

__init__

def __init__(time_distribution: TimeDistribution,
             prior_distribution: DiscretePriorDistribution,
             device: str = "cpu",
             eps: Float = 1e-5,
             rng_generator: Optional[torch.Generator] = None)

Initialize the DFM interpolant.

Arguments:

  • time_distribution TimeDistribution - The time distribution for the diffusion process.
  • prior_distribution DiscretePriorDistribution - The prior distribution for the discrete masked tokens.
  • device str, optional - The device to use for computations. Defaults to "cpu".
  • eps - small Float to prevent dividing by zero.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

interpolate

def interpolate(data: Tensor, t: Tensor, noise: Tensor)

Get x(t) with given time t from noise and data.

Arguments:

  • data Tensor - target discrete ids
  • t Tensor - time
  • noise - tensor noise ids

loss

def loss(logits: Tensor,
         target: Tensor,
         time: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         use_weight: Bool = False)

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the target batch x node.
If using a masked prior, please pass in the correct mask to calculate loss values only on masked states,
i.e. mask = data_mask * is_masked_state, which is calculated with self.prior_dist.is_masked(xt).

If use_weight is True, the loss is weighted by 1/(1-t), as defined in equation 24 of Appendix C in https://arxiv.org/pdf/2402.04997.

Arguments:

  • logits Tensor - The predicted output from the model, with shape batch x node x class.
  • target Tensor - The target output for the model prediction, with shape batch x node.
  • time Tensor - The time at which the loss is calculated.
  • mask Optional[Tensor], optional - The mask for the data point. Defaults to None.
  • use_weight bool, optional - Whether to use the DFM time weight for the loss. Defaults to False.

Returns:

  • Tensor - The calculated loss batch tensor.

step

def step(logits: Tensor,
         t: Tensor,
         xt: Tensor,
         dt: Tensor | float,
         temperature: Float = 1.0,
         stochasticity: Float = 1.0) -> Tensor

Perform a single DFM Euler update step.

Arguments:

  • logits Tensor - The input logits.
  • t Tensor - The current time step.
  • xt Tensor - The current state.
  • dt Tensor | float - The time step increment.
  • temperature Float, optional - The temperature for the softmax calculation. Defaults to 1.0.
  • stochasticity Float, optional - The stochasticity value for the step calculation. Defaults to 1.0.

Returns:

  • Tensor - The updated state.

step_purity

def step_purity(logits: Tensor,
                t: Tensor,
                xt: Tensor,
                dt: Tensor | float,
                temperature: Float = 1.0,
                stochasticity: Float = 1.0) -> Tensor

Perform a single step of purity sampling.

https://github.com/jasonkyuyim/multiflow/blob/6278899970523bad29953047e7a42b32a41dc813/multiflow/data/interpolant.py#L346
Here's a high-level overview of what the function does:
TODO: check if the -1e9 and 1e-9 are small enough or using torch.inf would be better

  1. Preprocessing:
    Checks if dt is a float and converts it to a tensor if necessary.
    Pads t and dt to match the shape of xt.
    Checks if the mask_index is valid (i.e., within the range of possible discrete values).
  2. Masking:
    Sets the logits corresponding to the mask_index to a low value (-1e9) to effectively mask out those values.
    Computes the softmax probabilities of the logits.
    Sets the probability of the mask_index to a small value (1e-9) to avoid numerical issues.
  3. Purity sampling:
    Computes the maximum log probabilities of the softmax distribution.
    Computes the indices of the top-number_to_unmask samples with the highest log probabilities.
    Uses these indices to sample new values from the original distribution.
  4. Unmasking and updating:
    Creates a mask to select the top-number_to_unmask samples.
    Uses this mask to update the current state xt with the new samples.
  5. Re-masking:
    Generates a new mask to randomly re-mask some of the updated samples.
    Applies this mask to the updated state xt.

Arguments:

  • logits Tensor - The input logits.
  • t Tensor - The current time step.
  • xt Tensor - The current state.
  • dt Tensor - The time step increment.
  • temperature Float, optional - The temperature for the softmax calculation. Defaults to 1.0.
  • stochasticity Float, optional - The stochasticity value for the step calculation. Defaults to 1.0.

Returns:

  • Tensor - The updated state.

step_argmax

def step_argmax(model_out: Tensor)

Returns the index of the maximum value in the last dimension of the model output.

Arguments:

  • model_out Tensor - The output of the model.

step_simple_sample

def step_simple_sample(model_out: Tensor,
                       temperature: float = 1.0,
                       num_samples: int = 1)

Samples from the model output logits. Leads to more diversity than step_argmax.

Arguments:

  • model_out Tensor - The output of the model.
  • temperature Float, optional - The temperature for the softmax calculation. Defaults to 1.0.
  • num_samples int - Number of samples to return
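
For intuition, the two decodings differ only in how they reduce the class dimension. A rough sketch of the equivalent raw torch operations, assuming model_out has shape batch x node x class and temperature is a float:

import torch

# Deterministic: pick the most likely class per position (step_argmax).
x_argmax = model_out.argmax(dim=-1)

# Stochastic: temperature-scaled softmax, then categorical sampling
# (step_simple_sample with num_samples=1).
probs = torch.softmax(model_out / temperature, dim=-1)
x_sampled = torch.distributions.Categorical(probs=probs).sample()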

bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.ot_sampler

OTSampler Objects

class OTSampler()

Sampler for Exact Mini-batch Optimal Transport Plan.

OTSampler implements sampling coordinates according to an OT plan (wrt squared Euclidean cost)
with different implementations of the plan calculation. Code is adapted from https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py

__init__

def __init__(method: str = "exact",
             device: Union[str, torch.device] = "cpu",
             num_threads: int = 1) -> None

Initialize the OTSampler class.

Arguments:

  • method str - Choose which optimal transport solver you would like to use. Currently only the exact OT solver (pot.emd) is supported.
  • device Union[str, torch.device], optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
  • num_threads Union[int, str], optional - Number of threads to use for OT solver. If "max", uses the maximum number of threads. Default is 1.

Raises:

  • ValueError - If the OT solver is not documented.
  • NotImplementedError - If the OT solver is not implemented.

to_device

def to_device(device: str)

Moves all internal tensors to the specified device and updates the self.device attribute.

Arguments:

  • device str - The device to move the tensors to (e.g. "cpu", "cuda:0").

Notes:

This method is used to transfer the internal state of the OTSampler to a different device.
It updates the self.device attribute to reflect the new device and moves all internal tensors to the specified device.

sample_map

def sample_map(pi: Tensor,
               batch_size: int,
               replace: Bool = False) -> Tuple[Tensor, Tensor]

Draw source and target samples from pi $(x,z) \sim \pi$.

Arguments:

  • pi Tensor - shape (bs, bs), the OT matrix between noise and data in minibatch.
  • batch_size int - The batch size of the minibatch.
  • replace bool - sampling w/ or w/o replacement from the OT plan, default to False.

Returns:

  • Tuple - tuple of 2 tensors, represents the indices of noise and data samples from pi.

get_ot_matrix

def get_ot_matrix(x0: Tensor,
                  x1: Tensor,
                  mask: Optional[Tensor] = None) -> Tensor

Compute the OT matrix between a source and a target minibatch.

Arguments:

  • x0 Tensor - shape (bs, *dim), noise from source minibatch.
  • x1 Tensor - shape (bs, *dim), data from source minibatch.
  • mask Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.

Returns:

  • p Tensor - shape (bs, bs), the OT matrix between noise and data in minibatch.

apply_augmentation

def apply_augmentation(
    x0: Tensor,
    x1: Tensor,
    mask: Optional[Tensor] = None,
    replace: Bool = False,
    sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0"
) -> Tuple[Tensor, Tensor, Optional[Tensor]]

Sample indices for noise and data in minibatch according to OT plan.

Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
minibatch and draw source and target samples from pi $(x,z) \sim \pi$.

Arguments:

  • x0 Tensor - shape (bs, *dim), noise from source minibatch.
  • x1 Tensor - shape (bs, *dim), data from source minibatch.
  • mask Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
  • replace bool - sampling w/ or w/o replacement from the OT plan, default to False.
  • sort str - Optional Literal string to sort either x1 or x0 based on the input.

Returns:

  • Tuple - tuple of 2 tensors or 3 tensors if mask is used, represents the noise (plus mask) and data samples following OT plan pi.
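
A minimal usage sketch, assuming 3D point clouds of shape (batch, nodes, 3); when a mask is passed, a third tensor is returned as well:

import torch

ot_sampler = OTSampler(method="exact", device="cpu", num_threads=1)
x0 = torch.randn(32, 16, 3)  # noise minibatch
x1 = torch.randn(32, 16, 3)  # data minibatch
# Re-pair noise and data according to the exact OT plan.
x0_ot, x1_ot = ot_sampler.apply_augmentation(x0, x1)[:2]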

bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.equivariant_ot_sampler

EquivariantOTSampler Objects

class EquivariantOTSampler()

Sampler for Mini-batch Optimal Transport Plan with cost calculated after Kabsch alignment.

EquivariantOTSampler implements sampling coordinates according to an OT plan
(wrt squared Euclidean cost after Kabsch alignment) with different implementations of the plan calculation.

__init__

def __init__(method: str = "exact",
             device: Union[str, torch.device] = "cpu",
             num_threads: int = 1) -> None

Initialize the EquivariantOTSampler class.

Arguments:

  • method str - Choose which optimal transport solver you would like to use. Currently only the exact OT solver (pot.emd) is supported.
  • device Union[str, torch.device], optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
  • num_threads Union[int, str], optional - Number of threads to use for OT solver. If "max", uses the maximum number of threads. Default is 1.

Raises:

  • ValueError - If the OT solver is not documented.
  • NotImplementedError - If the OT solver is not implemented.

to_device

def to_device(device: str)

Moves all internal tensors to the specified device and updates the self.device attribute.

Arguments:

  • device str - The device to move the tensors to (e.g. "cpu", "cuda:0").

Notes:

This method is used to transfer the internal state of the EquivariantOTSampler to a different device.
It updates the self.device attribute to reflect the new device and moves all internal tensors to the specified device.

sample_map

def sample_map(pi: Tensor,
               batch_size: int,
               replace: Bool = False) -> Tuple[Tensor, Tensor]

Draw source and target samples from pi $(x,z) \sim \pi$.

Arguments:

  • pi Tensor - shape (bs, bs), the OT matrix between noise and data in minibatch.
  • batch_size int - The batch size of the minibatch.
  • replace bool - sampling w/ or w/o replacement from the OT plan, default to False.

Returns:

  • Tuple - tuple of 2 tensors, represents the indices of noise and data samples from pi.

kabsch_align

def kabsch_align(target: Tensor, noise: Tensor) -> Tensor

Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

Arguments:

  • target Tensor - shape (N, *dim), data from source minibatch.
  • noise Tensor - shape (N, *dim), noise from source minibatch.

Returns:

  • R Tensor - shape (dim, dim), the rotation matrix.
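
The Kabsch solve itself is a short SVD computation. The following is a standalone reference sketch of the underlying algorithm, not this class's exact implementation, and it assumes 3D coordinates:

import torch

def kabsch_rotation(target: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    # Center both point clouds at the origin.
    t_c = target - target.mean(dim=0, keepdim=True)
    n_c = noise - noise.mean(dim=0, keepdim=True)
    # Cross-covariance between the two clouds, shape (dim, dim).
    H = t_c.T @ n_c
    U, S, Vh = torch.linalg.svd(H)
    # Correct for reflections so R is a proper rotation (det(R) = +1).
    D = torch.eye(3)
    D[-1, -1] = torch.sign(torch.linalg.det(Vh.T @ U.T))
    return Vh.T @ D @ U.T  # R such that target @ R.T approximates noise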

get_ot_matrix

def get_ot_matrix(x0: Tensor,
                  x1: Tensor,
                  mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]

Compute the OT matrix between a source and a target minibatch.

Arguments:

  • x0 Tensor - shape (bs, *dim), noise from source minibatch.
  • x1 Tensor - shape (bs, *dim), data from source minibatch.
  • mask Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.

Returns:

  • p Tensor - shape (bs, bs), the OT matrix between noise and data in minibatch.
  • Rs Tensor - shape (bs, bs, dim, dim), the rotation matrix between noise and data in minibatch.

apply_augmentation

def apply_augmentation(
    x0: Tensor,
    x1: Tensor,
    mask: Optional[Tensor] = None,
    replace: Bool = False,
    sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0"
) -> Tuple[Tensor, Tensor, Optional[Tensor]]

Sample indices for noise and data in minibatch according to OT plan.

Compute the OT plan $\pi$ (wrt squared Euclidean cost after Kabsch alignment) between a source and a target
minibatch and draw source and target samples from pi $(x,z) \sim \pi$.

Arguments:

  • x0 Tensor - shape (bs, *dim), noise from source minibatch.
  • x1 Tensor - shape (bs, *dim), data from source minibatch.
  • mask Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
  • replace bool - sampling w/ or w/o replacement from the OT plan, default to False.
  • sort str - Optional Literal string to sort either x1 or x0 based on the input.

Returns:

  • Tuple - tuple of 2 tensors, represents the noise and data samples following OT plan pi.

bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.kabsch_augmentation

KabschAugmentation Objects

class KabschAugmentation()

Point-wise Kabsch alignment.

__init__

def __init__()

Initialize the KabschAugmentation instance.

Notes:

  • This implementation assumes no required initialization arguments.
  • You can add instance variables (e.g., self.variable_name) as needed.

kabsch_align

def kabsch_align(target: Tensor, noise: Tensor)

Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

Arguments:

  • target Tensor - shape (N, *dim), data from source minibatch.
  • noise Tensor - shape (N, *dim), noise from source minibatch.

Returns:

  • R Tensor - shape (dim, dim), the rotation matrix.
    Aligned Target (Tensor): target tensor rotated and shifted to reduce RMSD with noise.

batch_kabsch_align

def batch_kabsch_align(target: Tensor, noise: Tensor)

Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

Arguments:

  • target Tensor - shape (B, N, *dim), data from source minibatch.
  • noise Tensor - shape (B, N, *dim), noise from source minibatch.

Returns:

  • R Tensor - shape (B, dim, dim), the batch of rotation matrices.
    Aligned Target (Tensor): target tensor rotated and shifted to reduce RMSD with noise.

apply_augmentation

def apply_augmentation(x0: Tensor,
                       x1: Tensor,
                       mask: Optional[Tensor] = None,
                       align_noise_to_data=True) -> Tuple[Tensor, Tensor]

Apply point-wise Kabsch alignment between batched noise and data.

Rather than re-pairing samples via an OT plan, this augmentation rotates and shifts one tensor onto the other (by default, the noise onto the data) to minimize the squared Euclidean error before interpolation.

Arguments:

  • x0 Tensor - shape (bs, *dim), noise from source minibatch.
  • x1 Tensor - shape (bs, *dim), data from source minibatch.
  • mask Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
  • align_noise_to_data bool - Direction of alignment. Defaults to True, meaning the noise is augmented to reduce its error to the data.

Returns:

  • Tuple - tuple of 2 tensors, the aligned noise and data samples.
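
A usage sketch, assuming coordinate tensors of shape (batch, nodes, 3):

import torch

kabsch = KabschAugmentation()
x0 = torch.randn(32, 16, 3)  # noise
x1 = torch.randn(32, 16, 3)  # data
# Rotate and shift the noise onto the data (align_noise_to_data=True).
x0_aligned, x1_out = kabsch.apply_augmentation(x0, x1, align_noise_to_data=True)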

bionemo.moco.interpolants.continuous_time.continuous.data_augmentation

bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.augmentation_types

AugmentationType Objects

class AugmentationType(Enum)

An enumeration representing the type of Optimal Transport that can be used in Continuous Flow Matching.

These augmentation types determine how noise and data samples are paired within a minibatch before interpolation.

bionemo.moco.interpolants.continuous_time.continuous

bionemo.moco.interpolants.continuous_time.continuous.vdm

VDM Objects

class VDM(Interpolant)

A Variational Diffusion Models (VDM) interpolant.


Examples:

>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.continuous.vdm import VDM
>>> from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule


vdm = VDM(
    time_distribution = UniformTimeDistribution(...),
    prior_distribution = GaussianPrior(...),
    noise_schedule = CosineSNRTransform(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = vdm.sample_time(batch_size)
    noise = vdm.sample_prior(data.shape)
    xt = vdm.interpolate(data, time, noise)

    x_pred = model(xt, time)
    loss = vdm.loss(x_pred, data, time)
    loss.backward()

# Generation
x_pred = vdm.sample_prior(data.shape)
for t in LinearInferenceSchedule(...).generate_schedule():
    time = torch.full((batch_size,), t)
    x_hat = model(x_pred, time)
    x_pred = vdm.step(x_hat, time, x_pred)
return x_pred

__init__

def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             noise_schedule: ContinuousSNRTransform,
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
             device: Union[str, torch.device] = "cpu",
             rng_generator: Optional[torch.Generator] = None)

Initializes the VDM interpolant.

Arguments:

  • time_distribution TimeDistribution - The distribution of time steps, used to sample time points for the diffusion process.
  • prior_distribution PriorDistribution - The prior distribution of the variable, used as the starting point for the diffusion process.
  • noise_schedule ContinuousSNRTransform - The schedule of noise, defining the amount of noise added at each time step.
  • prediction_type PredictionType, optional - The type of prediction, either "data" or another type. Defaults to "data".
  • device str, optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

interpolate

def interpolate(data: Tensor, t: Tensor, noise: Tensor)

Get x(t) with given time t from noise and data.

Arguments:

  • data Tensor - target
  • t Tensor - time
  • noise Tensor - noise from prior()

forward_process

def forward_process(data: Tensor, t: Tensor, noise: Optional[Tensor] = None)

Get x(t) with given time t from noise and data.

Arguments:

  • data Tensor - target
  • t Tensor - time
  • noise Tensor, optional - noise from prior(). Defaults to None

process_data_prediction

def process_data_prediction(model_output: Tensor, sample, t)

Converts the model output to a data prediction based on the prediction type.

This conversion stems from the Progressive Distillation for Fast Sampling of Diffusion Models https://arxiv.org/pdf/2202.00512.
Given the model output and the sample, we convert the output to a data prediction based on the prediction type.
The conversion formulas are as follows:
- For "noise" prediction type: pred_data = (sample - noise_scale * model_output) / data_scale
- For "data" prediction type: pred_data = model_output
- For "v_prediction" prediction type: pred_data = data_scale * sample - noise_scale * model_output

Arguments:

  • model_output Tensor - The output of the model.
  • sample Tensor - The input sample.
  • t Tensor - The time step.

Returns:

The data prediction based on the prediction type.

Raises:

  • ValueError - If the prediction type is not one of "noise", "data", or "v_prediction".
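
As a concrete illustration, writing data_scale = sqrt(alpha_bar(t)) and noise_scale = sqrt(1 - alpha_bar(t)) (the usual diffusion coefficients; an assumption for illustration, not this class's internal names), the three branches amount to:

# Hypothetical sketch of the conversion branches above.
if prediction_type == "noise":
    pred_data = (sample - noise_scale * model_output) / data_scale
elif prediction_type == "data":
    pred_data = model_output
elif prediction_type == "v_prediction":
    pred_data = data_scale * sample - noise_scale * model_output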

process_noise_prediction

def process_noise_prediction(model_output: Tensor, sample: Tensor, t: Tensor)

Does the same as process_data_prediction but converts the model output to noise.

Arguments:

  • model_output Tensor - The output of the model.
  • sample Tensor - The input sample.
  • t Tensor - The time step.

Returns:

The input as noise if the prediction type is "noise".

Raises:

  • ValueError - If the prediction type is not "noise".

step

def step(model_out: Tensor,
         t: Tensor,
         xt: Tensor,
         dt: Tensor,
         mask: Optional[Tensor] = None,
         center: Bool = False,
         temperature: Float = 1.0)

Do one step integration.

Arguments:

  • model_out Tensor - The output of the model.
  • xt Tensor - The current data point.
  • t Tensor - The current time step.
  • dt Tensor - The time step increment.
  • mask Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.
  • center bool - Whether to center the data. Defaults to False.
  • temperature Float - The temperature parameter for low temperature sampling. Defaults to 1.0.

Notes:

The temperature parameter controls the trade-off between diversity and sample quality.
Decreasing the temperature sharpens the sampling distribution to focus on more likely samples.
The impact of low-temperature sampling should be ablated empirically.

score

def score(x_hat: Tensor, xt: Tensor, t: Tensor)

Converts the data prediction to the estimated score function.

Arguments:

  • x_hat tensor - The predicted data point.
  • xt Tensor - The current data point.
  • t Tensor - The time step.

Returns:

The estimated score function.

step_ddim

def step_ddim(model_out: Tensor,
              t: Tensor,
              xt: Tensor,
              dt: Tensor,
              mask: Optional[Tensor] = None,
              eta: Float = 0.0,
              center: Bool = False)

Do one step of DDIM sampling.

From the DDPM equations, alpha_bar = alpha**2 and 1 - alpha**2 = sigma**2.

Arguments:

  • model_out Tensor - output of the model
  • t Tensor - current time step
  • xt Tensor - current data point
  • dt Tensor - The time step increment.
  • mask Optional[Tensor], optional - mask for the data point. Defaults to None.
  • eta Float, optional - DDIM sampling parameter. Defaults to 0.0.
  • center Bool, optional - whether to center the data point. Defaults to False.

set_loss_weight_fn

def set_loss_weight_fn(fn: Callable)

Sets the loss_weight attribute of the instance to the given function.

Arguments:

  • fn - The function to set as the loss_weight attribute. This function should take three arguments: raw_loss, t, and weight_type.
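
For example, a hypothetical custom weight function matching this interface, which ignores weight_type and down-weights late time steps:

import torch

def custom_loss_weight(raw_loss, t, weight_type):
    # Down-weight losses at large t; the clamp avoids zero weights.
    return (1.0 - t).clamp(min=0.1)

vdm.set_loss_weight_fn(custom_loss_weight)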

loss_weight

def loss_weight(raw_loss: Tensor,
                t: Tensor,
                weight_type: str,
                dt: Float = 0.001) -> Tensor

Calculates the weight for the loss based on the given weight type.

This function computes the loss weight according to the specified weight_type.
The available weight types are:
- "ones": uniform weight of 1.0
- "data_to_noise": derived from Equation (9) of https://arxiv.org/pdf/2202.00512
- "variational_objective_discrete": based on the variational objective, see https://arxiv.org/pdf/2202.00512

Arguments:

  • raw_loss Tensor - The raw loss calculated from the model prediction and target.
  • t Tensor - The time step.
  • weight_type str - The type of weight to use. Can be "ones", "data_to_noise", or "variational_objective_discrete".
  • dt Float, optional - The time step increment. Defaults to 0.001.

Returns:

  • Tensor - The weight for the loss.

Raises:

  • ValueError - If the weight type is not recognized.

loss

def loss(model_pred: Tensor,
         target: Tensor,
         t: Tensor,
         dt: Optional[Float] = 0.001,
         mask: Optional[Tensor] = None,
         weight_type: str = "ones")

Calculates the loss given the model prediction, target, and time.

Arguments:

  • model_pred Tensor - The predicted output from the model.
  • target Tensor - The target output for the model prediction.
  • t Tensor - The time at which the loss is calculated.
  • dt Optional[Float], optional - The time step increment. Defaults to 0.001.
  • mask Optional[Tensor], optional - The mask for the data point. Defaults to None.
  • weight_type str, optional - The type of weight to use for the loss. Can be "ones", "data_to_noise", or "variational_objective". Defaults to "ones".

Returns:

  • Tensor - The calculated loss batch tensor.

step_hybrid_sde

def step_hybrid_sde(model_out: Tensor,
                    t: Tensor,
                    xt: Tensor,
                    dt: Tensor,
                    mask: Optional[Tensor] = None,
                    center: Bool = False,
                    temperature: Float = 1.0,
                    equilibrium_rate: Float = 0.0) -> Tensor

Do one step integration of Hybrid Langevin-Reverse Time SDE.

See section B.3 page 37 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf.
and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

Arguments:

  • model_out Tensor - The output of the model.
  • xt Tensor - The current data point.
  • t Tensor - The current time step.
  • dt Tensor - The time step increment.
  • mask Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.
  • center bool, optional - Whether to center the data. Defaults to False.
  • temperature Float, optional - The temperature parameter for low temperature sampling. Defaults to 1.0.
  • equilibrium_rate Float, optional - The rate of Langevin equilibration. Scales the amount of Langevin dynamics per unit time. Best values are in the range [1.0, 5.0]. Defaults to 0.0.

Notes:

For all step functions that use the SDE formulation, it's important to note that we are moving backwards in time, which corresponds to an apparent sign change.
A clear example can be seen in slide 29 https://ernestryu.com/courses/FM/diffusion1.pdf.

step_ode

def step_ode(model_out: Tensor,
             t: Tensor,
             xt: Tensor,
             dt: Tensor,
             mask: Optional[Tensor] = None,
             center: Bool = False,
             temperature: Float = 1.0) -> Tensor

Do one step integration of ODE.

See section B page 36 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf.
and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

Arguments:

  • model_out Tensor - The output of the model.
  • xt Tensor - The current data point.
  • t Tensor - The current time step.
  • dt Tensor - The time step increment.
  • mask Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.
  • center bool, optional - Whether to center the data. Defaults to False.
  • temperature Float, optional - The temperature parameter for low temperature sampling. Defaults to 1.0.

bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching

ContinuousFlowMatcher Objects

class ContinuousFlowMatcher(Interpolant)

A Continuous Flow Matching interpolant.


Examples:

>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching import ContinuousFlowMatcher
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

flow_matcher = ContinuousFlowMatcher(
    time_distribution = UniformTimeDistribution(...),
    prior_distribution = GaussianPrior(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = flow_matcher.sample_time(batch_size)
    noise = flow_matcher.sample_prior(data.shape)
    noise, data = flow_matcher.apply_augmentation(noise, data) # Optional, only for OT
    xt = flow_matcher.interpolate(data, time, noise)
    flow = flow_matcher.calculate_target(data, noise)

    u_pred = model(xt, time)
    loss = flow_matcher.loss(u_pred, flow)
    loss.backward()

# Generation
x_pred = flow_matcher.sample_prior(data.shape)
inference_sched = LinearInferenceSchedule(...)
for t in inference_sched.generate_schedule():
    time = inference_sched.pad_time(x_pred.shape[0], t)
    u_hat = model(x_pred, time)
    x_pred = flow_matcher.step(u_hat, x_pred, time)
return x_pred

__init__

def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
             sigma: Float = 0,
             augmentation_type: Optional[Union[AugmentationType, str]] = None,
             augmentation_num_threads: int = 1,
             data_scale: Float = 1.0,
             device: Union[str, torch.device] = "cpu",
             rng_generator: Optional[torch.Generator] = None,
             eps: Float = 1e-5)

Initializes the Continuous Flow Matching interpolant.

Arguments:

  • time_distribution TimeDistribution - The distribution of time steps, used to sample time points for the diffusion process.
  • prior_distribution PriorDistribution - The prior distribution of the variable, used as the starting point for the diffusion process.
  • prediction_type PredictionType, optional - The type of prediction, either "flow" or another type. Defaults to PredictionType.DATA.
  • sigma Float, optional - The standard deviation of the Gaussian noise added to the interpolated data. Defaults to 0.
  • augmentation_type Optional[Union[AugmentationType, str]], optional - The type of optimal transport, if applicable. Defaults to None.
  • augmentation_num_threads - Number of threads to use for OT solver. If "max", uses the maximum number of threads. Default is 1.
  • data_scale Float, optional - The scale factor for the data. Defaults to 1.0.
  • device Union[str, torch.device], optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.
  • eps - Small float to prevent divide by zero

apply_augmentation

def apply_augmentation(x0: Tensor,
                       x1: Tensor,
                       mask: Optional[Tensor] = None,
                       **kwargs) -> tuple

Sample and apply the optimal transport plan between batched (and masked) x0 and x1.

Arguments:

  • x0 Tensor - shape (bs, *dim), noise from source minibatch.
  • x1 Tensor - shape (bs, *dim), data from source minibatch.
  • mask Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
  • **kwargs - Additional keyword arguments to be passed to self.augmentation_sampler.apply_augmentation or handled within this method.

Returns:

  • Tuple - tuple of 2 tensors, represents the noise and data samples following OT plan pi.

undo_scale_data

def undo_scale_data(data: Tensor) -> Tensor

Downscale the input data by the data scale factor.

Arguments:

  • data Tensor - The input data to downscale.

Returns:

The downscaled data.

scale_data

def scale_data(data: Tensor) -> Tensor

Upscale the input data by the data scale factor.

Arguments:

  • data Tensor - The input data to upscale.

Returns:

The upscaled data.

interpolate

def interpolate(data: Tensor, t: Tensor, noise: Tensor) -> Tensor

Get x_t with given time t from noise (x_0) and data (x_1).

Currently, we use the linear interpolation as defined in:
1. Rectified flow: https://arxiv.org/abs/2209.03003.
2. Conditional flow matching: https://arxiv.org/abs/2210.02747 (called conditional optimal transport).

Arguments:

  • noise Tensor - noise from prior(), shape (batchsize, nodes, features)
  • t Tensor - time, shape (batchsize)
  • data Tensor - target, shape (batchsize, nodes, features)
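
Conceptually, with time padded to broadcast over nodes and features, the linear interpolation is (a schematic sketch; the implementation also handles the optional sigma noise):

t_pad = t[:, None, None]                    # (batchsize, 1, 1)
xt = (1.0 - t_pad) * noise + t_pad * data   # x_t = (1 - t) * x0 + t * x1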

calculate_target

def calculate_target(data: Tensor,
                     noise: Tensor,
                     mask: Optional[Tensor] = None) -> Tensor

Get the target vector field at time t.

Arguments:

  • noise Tensor - noise from prior(), shape (batchsize, nodes, features)
  • data Tensor - target, shape (batchsize, nodes, features)
  • mask Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.

Returns:

  • Tensor - The target vector field at time t.

process_vector_field_prediction

def process_vector_field_prediction(model_output: Tensor,
                                    xt: Optional[Tensor] = None,
                                    t: Optional[Tensor] = None,
                                    mask: Optional[Tensor] = None)

Process the model output based on the prediction type to calculate the vector field.

Arguments:

  • model_output Tensor - The output of the model.
  • xt Tensor - The input sample.
  • t Tensor - The time step.
  • mask Optional[Tensor], optional - An optional mask to apply to the model output. Defaults to None.

Returns:

The vector field prediction based on the prediction type.

Raises:

  • ValueError - If the prediction type is not "flow" or "data".

process_data_prediction

def process_data_prediction(model_output: Tensor,
                            xt: Optional[Tensor] = None,
                            t: Optional[Tensor] = None,
                            mask: Optional[Tensor] = None)

Process the model output based on the prediction type to generate clean data.

Arguments:

  • model_output Tensor - The output of the model.
  • xt Tensor - The input sample.
  • t Tensor - The time step.
  • mask Optional[Tensor], optional - An optional mask to apply to the model output. Defaults to None.

Returns:

The data prediction based on the prediction type.

Raises:

  • ValueError - If the prediction type is not "flow".

step

def step(model_out: Tensor,
         xt: Tensor,
         dt: Tensor,
         t: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         center: Bool = False)

Perform a single ODE step integration using Euler method.

Arguments:

  • model_out Tensor - The output of the model at the current time step.
  • xt Tensor - The current intermediate state.
  • dt Tensor - The time step size.
  • t Tensor, optional - The current time. Defaults to None.
  • mask Optional[Tensor], optional - A mask to apply to the model output. Defaults to None.
  • center Bool, optional - Whether to center the output. Defaults to False.

Returns:

  • x_next Tensor - The updated state of the system after the single step, x_(t+dt).

Notes:

  • If a mask is provided, it is applied element-wise to the model output before scaling.
  • The clean method is called on the updated state before it is returned.

step_score_stochastic

def step_score_stochastic(model_out: Tensor,
                          xt: Tensor,
                          dt: Tensor,
                          t: Tensor,
                          mask: Optional[Tensor] = None,
                          gt_mode: str = "tan",
                          gt_p: Float = 1.0,
                          gt_clamp: Optional[Float] = None,
                          score_temperature: Float = 1.0,
                          noise_temperature: Float = 1.0,
                          t_lim_ode: Float = 0.99,
                          center: Bool = False)

Perform a single SDE step integration using a score-based Langevin update.

d x_t = [v(x_t, t) + g(t) * s(x_t, t) * score_temperature] dt + \sqrt{2 * g(t) * noise_temperature} dw_t.

Arguments:

  • model_out Tensor - The output of the model at the current time step.
  • xt Tensor - The current intermediate state.
  • dt Tensor - The time step size.
  • t Tensor, optional - The current time. Defaults to None.
  • mask Optional[Tensor], optional - A mask to apply to the model output. Defaults to None.
  • gt_mode str, optional - The mode for the gt function. Defaults to "tan".
  • gt_p Float, optional - The parameter for the gt function. Defaults to 1.0.
  • gt_clamp Float, optional - Upper limit of the gt term. Defaults to None.
  • score_temperature Float, optional - The temperature for the score part of the step. Defaults to 1.0.
  • noise_temperature Float, optional - The temperature for the stochastic part of the step. Defaults to 1.0.
  • t_lim_ode Float, optional - The time limit for the ODE step. Defaults to 0.99.
  • center Bool, optional - Whether to center the output. Defaults to False.

Returns:

  • x_next Tensor - The updated state of the system after the single step, x_(t+dt).

Notes:

  • If a mask is provided, it is applied element-wise to the model output before scaling.
  • The clean method is called on the updated state before it is returned.

loss

def loss(model_pred: Tensor,
         target: Tensor,
         t: Optional[Tensor] = None,
         xt: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         target_type: Union[PredictionType, str] = PredictionType.DATA)

Calculate the loss given the model prediction, data sample, time, and mask.

If target_type is FLOW: loss = ||v_hat - (x1 - x0)||^2.
If target_type is DATA: loss = ||x1_hat - x1||^2 * 1 / (1 - t)^2, since the target vector field is
x1 - x0 = (x1 - xt) / (1 - t), where xt = t * x1 + (1 - t) * x0.
This function supports any combination of prediction_type and target_type in {DATA, FLOW}.

Arguments:

  • model_pred Tensor - The predicted output from the model.
  • target Tensor - The target output for the model prediction.
  • t Optional[Tensor], optional - The time for the model prediction. Defaults to None.
  • xt Optional[Tensor], optional - The interpolated data. Defaults to None.
  • mask Optional[Tensor], optional - The mask for the data point. Defaults to None.
  • target_type PredictionType, optional - The type of the target output. Defaults to PredictionType.DATA.

Returns:

  • Tensor - The calculated loss batch tensor.

vf_to_score

def vf_to_score(x_t: Tensor, v: Tensor, t: Tensor) -> Tensor

From Geffner et al. Computes score of noisy density given the vector field learned by flow matching.

With our interpolation scheme these are related by

v(x_t, t) = (1 / t) (x_t + scale_ref ** 2 * (1 - t) * s(x_t, t)),

or equivalently,

s(x_t, t) = (t * v(x_t, t) - x_t) / (scale_ref ** 2 * (1 - t)).

with scale_ref = 1

Arguments:

  • x_t - Noisy sample, shape [*, dim]
  • v - Vector field, shape [*, dim]
  • t - Interpolation time, shape [*] (must be < 1)

Returns:

Score of intermediate density, shape [*, dim].
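
With time padded for broadcasting, the second identity translates directly into code (sketch, with scale_ref = 1):

t_pad = t[..., None]                       # broadcast over the feature dimension
score = (t_pad * v - x_t) / (1.0 - t_pad)  # s(x_t, t)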

get_gt

def get_gt(t: Tensor,
           mode: str = "tan",
           param: float = 1.0,
           clamp_val: Optional[float] = None,
           eps: float = 1e-2) -> Tensor

From Geffner et al. Computes gt for different modes.

Arguments:

  • t - times where we'll evaluate, covers [0, 1), shape [nsteps]
  • mode - "us" or "tan"
  • param - parameter of the gt transformation.
  • clamp_val - value to clamp gt; no clamping if None.
  • eps - small value; leave at the default.

bionemo.moco.interpolants.continuous_time

bionemo.moco.interpolants

bionemo.moco.interpolants.batch_augmentation

BatchDataAugmentation Objects

class BatchDataAugmentation()

Facilitates the creation of batch augmentation objects based on specified optimal transport types.

Arguments:

  • device str - The device to use for computations (e.g., 'cpu', 'cuda').
  • num_threads int - The number of threads to utilize.

__init__

def __init__(device, num_threads)

Initializes a BatchDataAugmentation instance.

Arguments:

  • device str - Device for computation.
  • num_threads int - Number of threads to use.

create

def create(method_type: AugmentationType)

Creates a batch augmentation object of the specified type.

Arguments:

  • method_type AugmentationType - The type of optimal transport method.

Returns:

The augmentation object if the type is supported, otherwise None.
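
A usage sketch; the member name AugmentationType.EXACT_OT is assumed here for illustration and should be checked against the AugmentationType enum:

factory = BatchDataAugmentation(device="cpu", num_threads=1)
augmentation = factory.create(AugmentationType.EXACT_OT)  # assumed member name
if augmentation is not None:
    x0_aug, x1_aug = augmentation.apply_augmentation(x0, x1)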

bionemo.moco.interpolants.discrete_time.discrete.d3pm

D3PM Objects

class D3PM(Interpolant)

A Discrete Denoising Diffusion Probabilistic Model (D3PM) interpolant.

__init__

def __init__(time_distribution: TimeDistribution,
             prior_distribution: DiscretePriorDistribution,
             noise_schedule: DiscreteNoiseSchedule,
             device: str = "cpu",
             last_time_idx: int = 0,
             rng_generator: Optional[torch.Generator] = None)

Initializes the D3PM interpolant.

Arguments:

  • time_distribution TimeDistribution - The distribution of time steps, used to sample time points for the diffusion process.
  • prior_distribution PriorDistribution - The prior distribution of the variable, used as the starting point for the diffusion process.
  • noise_schedule DiscreteNoiseSchedule - The schedule of noise, defining the amount of noise added at each time step.
  • device str, optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
  • last_time_idx int, optional - The last time index to consider in the interpolation process. Defaults to 0.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

interpolate

def interpolate(data: Tensor, t: Tensor)

Interpolate using discrete interpolation method.

This method implements Equation 2 from the D3PM paper (https://arxiv.org/pdf/2107.03006), which
calculates the interpolated discrete state xt at time t given the input data and noise
via q(xt|x0) = Cat(xt; p = x0*Qt_bar).

Arguments:

  • data Tensor - The input data to be interpolated.
  • t Tensor - The time step at which to interpolate.

Returns:

  • Tensor - The interpolated discrete state xt at time t.

forward_process

def forward_process(data: Tensor, t: Tensor) -> Tensor

Apply the forward process to the data at time t.

Arguments:

  • data Tensor - target discrete ids
  • t Tensor - time

Returns:

  • Tensor - x(t) after applying the forward process

step

def step(model_out: Tensor,
         t: Tensor,
         xt: Tensor,
         mask: Optional[Tensor] = None,
         temperature: Float = 1.0,
         model_out_is_logits: bool = True)

Perform a single step in the discrete interpolant method, transitioning from the current discrete state xt at time t to the next state.

This step involves:

  1. Computing the predicted q-posterior logits using the model output model_out and the current state xt at time t.
  2. Sampling the next state from the predicted q-posterior distribution using the Gumbel-Softmax trick.

Arguments:

  • model_out Tensor - The output of the model at the current time step, which is used to compute the predicted q-posterior logits.
  • t Tensor - The current time step, which is used to index into the transition matrices and compute the predicted q-posterior logits.
  • xt Tensor - The current discrete state at time t, which is used to compute the predicted q-posterior logits and sample the next state.
  • mask Optional[Tensor], optional - An optional mask to apply to the next state, which can be used to mask out certain tokens or regions. Defaults to None.
  • temperature Float, optional - The temperature to use for the Gumbel-Softmax trick, which controls the randomness of the sampling process. Defaults to 1.0.
  • model_out_is_logits bool, optional - A flag indicating whether the model output is already in logits form. If True, the output is assumed to be logits; otherwise, it is converted to logits. Defaults to True.

Returns:

  • Tensor - The next discrete state at time t-1.

loss

def loss(logits: Tensor,
         target: Tensor,
         xt: Tensor,
         time: Tensor,
         mask: Optional[Tensor] = None,
         vb_scale: Float = 0.0)

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the target batch x node. If a mask is provided, the loss is
calculated only for the non-masked elements. Additionally, if vb_scale is greater than 0, the variational lower bound loss is
calculated and added to the total loss.

Arguments:

  • logits Tensor - The predicted output from the model, with shape batch x node x class.
  • target Tensor - The target output for the model prediction, with shape batch x node.
  • xt Tensor - The current data point.
  • time Tensor - The time at which the loss is calculated.
  • mask Optional[Tensor], optional - The mask for the data point. Defaults to None.
  • vb_scale Float, optional - The scale factor for the variational lower bound loss. Defaults to 0.0.

Returns:

  • Tensor - The calculated loss tensor. If vb_scale is greater than 0, the variational lower bound loss is scaled and added to the
    cross-entropy loss; otherwise, the cross-entropy loss alone is returned.
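
D3PM follows the same training pattern as the other interpolants. A schematic sketch in the style of the examples above (model and data_loader are placeholders):

d3pm = D3PM(
    time_distribution = UniformTimeDistribution(discrete_time = True,...),
    prior_distribution = DiscreteUniformPrior(...),
    noise_schedule = DiscreteCosineNoiseSchedule(...),
    )
model = Model(...)

for epoch in range(1000):
    tokens = data_loader.get(...)        # batch x node discrete ids
    t = d3pm.sample_time(batch_size)
    xt = d3pm.interpolate(tokens, t)

    logits = model(xt, t)                # batch x node x class
    loss = d3pm.loss(logits, tokens, xt, t, vb_scale=0.0).mean()
    loss.backward()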

bionemo.moco.interpolants.discrete_time.discrete

bionemo.moco.interpolants.discrete_time.continuous.ddpm

DDPM Objects

class DDPM(Interpolant)

A Denoising Diffusion Probabilistic Model (DDPM) interpolant.


Examples:

>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
>>> from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
>>> from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule


ddpm = DDPM(
    time_distribution = UniformTimeDistribution(discrete_time = True,...),
    prior_distribution = GaussianPrior(...),
    noise_schedule = DiscreteCosineNoiseSchedule(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = ddpm.sample_time(batch_size)
    noise = ddpm.sample_prior(data.shape)
    xt = ddpm.interpolate(data, time, noise)

    x_pred = model(xt, time)
    loss = ddpm.loss(x_pred, data, time)
    loss.backward()

# Generation
x_pred = ddpm.sample_prior(data.shape)
for t in DiscreteLinearInferenceSchedule(...).generate_schedule():
    time = torch.full((batch_size,), t)
    x_hat = model(x_pred, time)
    x_pred = ddpm.step(x_hat, time, x_pred)
return x_pred

__init__

def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             noise_schedule: DiscreteNoiseSchedule,
             prediction_type: Union[PredictionType, str] = PredictionType.DATA,
             device: Union[str, torch.device] = "cpu",
             last_time_idx: int = 0,
             rng_generator: Optional[torch.Generator] = None)

Initializes the DDPM interpolant.

Arguments:

  • time_distribution TimeDistribution - The distribution of time steps, used to sample time points for the diffusion process.
  • prior_distribution PriorDistribution - The prior distribution of the variable, used as the starting point for the diffusion process.
  • noise_schedule DiscreteNoiseSchedule - The schedule of noise, defining the amount of noise added at each time step.
  • prediction_type PredictionType - The type of prediction, either "data" or another type. Defaults to "data".
  • device str - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
  • last_time_idx int, optional - The last time index for discrete time. Set to 0 if discrete time is T-1, ..., 0 or 1 if T, ..., 1. Defaults to 0.
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

forward_data_schedule

@property
def forward_data_schedule() -> torch.Tensor

Returns the forward data schedule.

forward_noise_schedule

@property
def forward_noise_schedule() -> torch.Tensor

Returns the forward noise schedule.

reverse_data_schedule

@property
def reverse_data_schedule() -> torch.Tensor

Returns the reverse data schedule.

reverse_noise_schedule

@property
def reverse_noise_schedule() -> torch.Tensor

Returns the reverse noise schedule.

log_var

@property
def log_var() -> torch.Tensor

Returns the log variance.

alpha_bar

@property
def alpha_bar() -> torch.Tensor

Returns the alpha bar values.

alpha_bar_prev

@property
def alpha_bar_prev() -> torch.Tensor

Returns the previous alpha bar values.

interpolate

def interpolate(data: Tensor, t: Tensor, noise: Tensor)

Get x(t) with given time t from noise and data.

Arguments:

  • data Tensor - target
  • t Tensor - time
  • noise Tensor - noise from prior()

forward_process

def forward_process(data: Tensor, t: Tensor, noise: Optional[Tensor] = None)

Get x(t) with given time t from noise and data.

Arguments:

  • data Tensor - target
  • t Tensor - time
  • noise Tensor, optional - noise from prior(). Defaults to None.

process_data_prediction

def process_data_prediction(model_output: Tensor, sample: Tensor, t: Tensor)

Converts the model output to a data prediction based on the prediction type.

This conversion stems from the Progressive Distillation for Fast Sampling of Diffusion Models https://arxiv.org/pdf/2202.00512.
Given the model output and the sample, we convert the output to a data prediction based on the prediction type.
The conversion formulas are as follows:
- For "noise" prediction type: pred_data = (sample - noise_scale * model_output) / data_scale
- For "data" prediction type: pred_data = model_output
- For "v_prediction" prediction type: pred_data = data_scale * sample - noise_scale * model_output

Arguments:

  • model_output Tensor - The output of the model.
  • sample Tensor - The input sample.
  • t Tensor - The time step.

Returns:

The data prediction based on the prediction type.

Raises:

  • ValueError - If the prediction type is not one of "noise", "data", or "v_prediction".

process_noise_prediction

def process_noise_prediction(model_output, sample, t)

Does the same as process_data_prediction but converts the model output to noise.

Arguments:

  • model_output - The output of the model.
  • sample - The input sample.
  • t - The time step.

Returns:

The input as noise if the prediction type is "noise".

Raises:

  • ValueError - If the prediction type is not "noise".

calculate_velocity

def calculate_velocity(data: Tensor, t: Tensor, noise: Tensor) -> Tensor

Calculate the velocity term given the data, time step, and noise.

Arguments:

  • data Tensor - The input data.
  • t Tensor - The current time step.
  • noise Tensor - The noise term.

Returns:

  • Tensor - The calculated velocity term.
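
This is the standard v-prediction target from Progressive Distillation (https://arxiv.org/pdf/2202.00512); schematically, with alpha_bar_t denoting alpha_bar indexed at t and broadcast to the data shape:

# v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * data  (sketch)
velocity = alpha_bar_t.sqrt() * noise - (1.0 - alpha_bar_t).sqrt() * data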

step

@torch.no_grad()
def step(model_out: Tensor,
         t: Tensor,
         xt: Tensor,
         mask: Optional[Tensor] = None,
         center: Bool = False,
         temperature: Float = 1.0)

Do one step integration.

Arguments:

  • model_out Tensor - The output of the model.
  • t Tensor - The current time step.
  • xt Tensor - The current data point.
  • mask Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.
  • center bool, optional - Whether to center the data. Defaults to False.
  • temperature Float, optional - The temperature parameter for low temperature sampling. Defaults to 1.0.

Notes:

The temperature parameter controls the level of randomness in the sampling process. A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result in less random and more deterministic samples. This can be useful for tasks that require more control over the generation process.

Note that for discrete time we sample from [T-1, ..., 1, 0] over T steps, so t = 0 is sampled, hence the mask.
For continuous time we start from [1, 1 - dt, ..., dt] over T steps, with s = t - dt; at t = 0, dt is treated as 0.

step_noise

def step_noise(model_out: Tensor,
               t: Tensor,
               xt: Tensor,
               mask: Optional[Tensor] = None,
               center: Bool = False,
               temperature: Float = 1.0)

Do one step integration.

Arguments:

  • model_out Tensor - The output of the model.
  • t Tensor - The current time step.
  • xt Tensor - The current data point.
  • mask Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.
  • center bool, optional - Whether to center the data. Defaults to False.
  • temperature Float, optional - The temperature parameter for low temperature sampling. Defaults to 1.0.

Notes:

The temperature parameter controls the level of randomness in the sampling process.
A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2)
result in less random and more deterministic samples. This can be useful for tasks
that require more control over the generation process.

Note that for discrete time we sample from [T-1, ..., 1, 0] over T steps, so t = 0 is sampled, hence the mask.
For continuous time we start from [1, 1 - dt, ..., dt] over T steps, with s = t - dt; at t = 0, dt is treated as 0.

score

def score(x_hat: Tensor, xt: Tensor, t: Tensor)

Converts the data prediction to the estimated score function.

Arguments:

  • x_hat Tensor - The predicted data point.
  • xt Tensor - The current data point.
  • t Tensor - The time step.

Returns:

The estimated score function.

step_ddim

def step_ddim(model_out: Tensor,
              t: Tensor,
              xt: Tensor,
              mask: Optional[Tensor] = None,
              eta: Float = 0.0,
              center: Bool = False)

Do one step of DDIM sampling.

Arguments:

  • model_out Tensor - output of the model
  • t Tensor - current time step
  • xt Tensor - current data point
  • mask Optional[Tensor], optional - mask for the data point. Defaults to None.
  • eta Float, optional - DDIM sampling parameter. Defaults to 0.0.
  • center Bool, optional - whether to center the data point. Defaults to False.

set_loss_weight_fn

def set_loss_weight_fn(fn)

Sets the loss_weight attribute of the instance to the given function.

Arguments:

  • fn - The function to set as the loss_weight attribute. This function should take three arguments: raw_loss, t, and weight_type.

loss_weight

def loss_weight(raw_loss: Tensor, t: Optional[Tensor],
                weight_type: str) -> Tensor

Calculates the weight for the loss based on the given weight type.

The data_to_noise loss weight is derived in Equation (9) of https://arxiv.org/pdf/2202.00512.

Arguments:

  • raw_loss Tensor - The raw loss calculated from the model prediction and target.
  • t Tensor - The time step.
  • weight_type str - The type of weight to use. Can be "ones" or "data_to_noise" or "noise_to_data".

Returns:

  • Tensor - The weight for the loss.

Raises:

  • ValueError - If the weight type is not recognized.

loss

def loss(model_pred: Tensor,
         target: Tensor,
         t: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         weight_type: Literal["ones", "data_to_noise",
                              "noise_to_data"] = "ones")

Calculate the loss given the model prediction, data sample, and time.

The default weight_type is "ones", meaning no change (multiplication by all ones).
data_to_noise scales the data MSE loss into the loss that is theoretically equivalent to noise prediction;
noise_to_data is provided for the converse case, for completeness.

Arguments:

  • model_pred Tensor - The predicted output from the model.
  • target Tensor - The target output for the model prediction.
  • t Tensor - The time at which the loss is calculated.
  • mask Optional[Tensor], optional - The mask for the data point. Defaults to None.
  • weight_type Literal["ones", "data_to_noise", "noise_to_data"] - The type of weight to use for the loss. Defaults to "ones".

Returns:

  • Tensor - The calculated loss batch tensor.

bionemo.moco.interpolants.discrete_time.continuous

bionemo.moco.interpolants.discrete_time

bionemo.moco.interpolants.discrete_time.utils

safe_index

def safe_index(tensor: Tensor, index: Tensor, device: Optional[torch.device])

Safely indexes a tensor using a given index and returns the result on a specified device.

Note: cross-device indexing could be forced with return tensor[index.to(tensor.device)].to(device), but this incurs a costly device migration.

Arguments:

  • tensor Tensor - The tensor to be indexed.
  • index Tensor - The index to use for indexing the tensor.
  • device torch.device - The device on which the result should be returned.

Returns:

  • Tensor - The indexed tensor on the specified device.

Raises:

  • ValueError - If the tensor and index are not on the same device.
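
A usage sketch, gathering per-sample schedule entries for integer time steps (names are hypothetical):

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)   # noise schedule, lives on CPU
t = torch.randint(0, 1000, (8,))           # per-sample integer time steps
betas_t = safe_index(betas, t, device=torch.device("cpu"))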

bionemo.moco.interpolants.base_interpolant

string_to_enum

def string_to_enum(value: Union[str, AnyEnum],
                   enum_type: Type[AnyEnum]) -> AnyEnum

Converts a string to an enum value of the specified type. If the input is already an enum instance, it is returned as-is.

Arguments:

  • value Union[str, E] - The string to convert or an existing enum instance.
  • enum_type Type[E] - The enum type to convert to.

Returns:

  • E - The corresponding enum value.

Raises:

  • ValueError - If the string does not correspond to any enum member.
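
For example, assuming the enum string values match the lowercase member names, as the signatures in this document suggest:

pred = string_to_enum("data", PredictionType)                # PredictionType.DATA
same = string_to_enum(PredictionType.NOISE, PredictionType)  # returned as-is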

pad_like

def pad_like(source: Tensor, target: Tensor) -> Tensor

Pads the dimensions of the source tensor to match the dimensions of the target tensor.

Arguments:

  • source Tensor - The tensor to be padded.
  • target Tensor - The tensor that the source tensor should match in dimensions.

Returns:

  • Tensor - The padded source tensor.

Raises:

  • ValueError - If the source tensor has more dimensions than the target tensor.

Example:

source = torch.tensor([1, 2, 3]) # shape: (3,)
target = torch.tensor([[1, 2], [4, 5], [7, 8]]) # shape: (3, 2)
padded_source = pad_like(source, target) # shape: (3, 1)

PredictionType Objects

class PredictionType(Enum)

An enumeration representing the type of prediction a Denoising Diffusion Probabilistic Model (DDPM) can be used for.

DDPMs are versatile models that can be utilized for various prediction tasks, including:

  • Data: Predicting the original data distribution from a noisy input.
  • Noise: Predicting the noise that was added to the original data to obtain the input.
  • Velocity: Predicting the velocity or rate of change of the data, particularly useful for modeling temporal dynamics.

These prediction types can be used to train neural networks for specific tasks, such as denoising, image synthesis, or time-series forecasting.

Interpolant Objects

class Interpolant(ABC)

An abstract base class representing an Interpolant.

This class serves as a foundation for creating interpolants that can be used
in various applications, providing a basic structure and interface for
interpolation-related operations.

__init__

def __init__(time_distribution: TimeDistribution,
             prior_distribution: PriorDistribution,
             device: Union[str, torch.device] = "cpu",
             rng_generator: Optional[torch.Generator] = None)

Initializes the Interpolant class.

Arguments:

  • time_distribution TimeDistribution - The distribution of time steps.
  • prior_distribution PriorDistribution - The prior distribution of the variable.
  • device Union[str, torch.device], optional - The device on which to operate. Defaults to "cpu".
  • rng_generator - An optional :class:torch.Generator for reproducible sampling. Defaults to None.

interpolate

@abstractmethod
def interpolate(*args, **kwargs) -> Tensor

Get x(t) with given time t from noise and data.

Interpolate between x0 and x1 at the given time t.

step

@abstractmethod
def step(*args, **kwargs) -> Tensor

Do one step integration.

general_step

def general_step(method_name: str, kwargs: dict)

Calls a step method of the class by its name, passing the provided keyword arguments.

Arguments:

  • method_name str - The name of the step method to call.
  • kwargs dict - Keyword arguments to pass to the step method.

Returns:

The result of the step method call.

Raises:

  • ValueError - If the provided method name does not start with 'step'.
  • Exception - If the step method call fails. The error message includes a list of available step methods.

Notes:

This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.

sample_prior

def sample_prior(*args, **kwargs) -> Tensor

Sample from prior distribution.

This method generates a sample from the prior distribution specified by the
prior_distribution attribute.

Returns:

  • Tensor - The generated sample from the prior distribution.

sample_time

def sample_time(*args, **kwargs) -> Tensor

Sample from time distribution.

to_device

def to_device(device: str)

Moves all internal tensors to the specified device and updates the self.device attribute.

Arguments:

  • device str - The device to move the tensors to (e.g. "cpu", "cuda:0").

Notes:

This method is used to transfer the internal state of the DDPM interpolant to a different device.
It updates the self.device attribute to reflect the new device and moves all internal tensors to the specified device.

clean_mask_center

def clean_mask_center(data: Tensor,
                      mask: Optional[Tensor] = None,
                      center: Bool = False) -> Tensor

Returns a clean tensor that has been masked and/or centered based on the function arguments.

Arguments:

  • data - The input data with shape (..., nodes, features).
  • mask - An optional mask to apply to the data with shape (..., nodes). If provided, it is used to calculate the CoM. Defaults to None.
  • center - A boolean indicating whether to center the data around the calculated CoM. Defaults to False.

Returns:

The data with shape (..., nodes, features) either centered around the CoM if center is True or unchanged if center is False.
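
Schematically, the masked center-of-mass (CoM) centering corresponds to (sketch, assuming mask has shape (..., nodes)):

m = mask.unsqueeze(-1)                                            # (..., nodes, 1)
com = (data * m).sum(dim=-2, keepdim=True) / m.sum(dim=-2, keepdim=True)
centered = (data - com) * m                                       # re-apply the mask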

bionemo.moco.testing

bionemo.moco.testing.parallel_test_utils

parallel_context

@contextmanager
def parallel_context(rank: int = 0, world_size: int = 1)

Context manager for torch distributed testing.

Sets up and cleans up the distributed environment, including the device mesh.

Arguments:

  • rank int - The rank of the process. Defaults to 0.
  • world_size int - The world size of the distributed environment. Defaults to 1.

Yields:

None
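
A typical test usage sketch (the test body is a placeholder):

from bionemo.moco.testing.parallel_test_utils import parallel_context

def test_my_distributed_op():
    with parallel_context(rank=0, world_size=1):
        ...  # distributed setup is active inside this block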

clean_up_distributed

def clean_up_distributed() -> None

Cleans up the distributed environment.

Destroys the process group and empties the CUDA cache.

Arguments:

None

Returns:

None