class PriorDistribution(ABC)
An abstract base class representing a prior distribution.
@abstractmethod
def sample(shape: Tuple,
mask: Optional[Tensor] = None,
device: Union[str, torch.device] = "cpu") -> Tensor
Generates a specified number of samples from the prior distribution.
Arguments:
shape
Tuple - The shape of the samples to generate.
mask
Optional[Tensor], optional - A tensor indicating which samples should be masked. Defaults to None.
device
str, optional - The device on which to generate the samples. Defaults to "cpu".
Returns:
Tensor
- A tensor of samples.
class DiscretePriorDistribution(PriorDistribution)
An abstract base class representing a discrete prior distribution.
def __init__(num_classes: int, prior_dist: Tensor)
Initializes a DiscretePriorDistribution instance.
Arguments:
num_classes
int - The number of classes in the discrete distribution.
prior_dist
Tensor - The prior distribution over the classes.
Returns:
None
def get_num_classes() -> int
Getter for num_classes.
def get_prior_dist() -> Tensor
Getter for prior_dist.
class DiscreteUniformPrior(DiscretePriorDistribution)
A subclass representing a discrete uniform prior distribution.
def __init__(num_classes: int = 10) -> None
Initializes a discrete uniform prior distribution.
Arguments:
num_classes
int - The number of classes in the discrete uniform distribution. Defaults to 10.
def sample(shape: Tuple,
mask: Optional[Tensor] = None,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None) -> Tensor
Generates a specified number of samples.
Arguments:
shape
Tuple - The shape of the samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
mask
Optional[Tensor] - An optional mask to apply to the samples. Defaults to None.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
Tensor
- A tensor of samples.
class DiscreteCustomPrior(DiscretePriorDistribution)
A subclass representing a discrete custom prior distribution.
This class allows for the creation of a prior distribution with a custom
probability mass function defined by the prior_dist tensor. For example, if the data has 4 classes and the desired class probabilities are [0.3, 0.2, 0.4, 0.1], pass that tensor as prior_dist.
def __init__(prior_dist: Tensor, num_classes: int = 10) -> None
Initializes a DiscreteCustomPrior distribution.
Arguments:
prior_dist
- A tensor representing the probability mass function of the prior distribution.
num_classes
- The number of classes in the prior distribution. Defaults to 10.
Notes:
The prior_dist tensor should sum to approximately 1.0, as it represents a probability mass function.
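For example, a minimal usage sketch based on the signatures documented here (the exact call pattern is an assumption from this documentation):
import torch
prior = DiscreteCustomPrior(prior_dist=torch.tensor([0.3, 0.2, 0.4, 0.1]), num_classes=4)
samples = prior.sample(shape=(8, 16))  # integer class indices drawn with the given probabilities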
def sample(shape: Tuple,
mask: Optional[Tensor] = None,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None) -> Tensor
Samples from the discrete custom prior distribution.
Arguments:
shape
- A tuple specifying the shape of the samples to generate.
mask
- An optional tensor mask to apply to the samples, broadcastable to the sample shape. Defaults to None.
device
- The device on which to generate the samples, specified as a string or a torch.device. Defaults to "cpu".
rng_generator
- An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
A tensor of samples drawn from the prior distribution.
class DiscreteMaskedPrior(DiscretePriorDistribution)
A subclass representing a Discrete Masked prior distribution.
def __init__(num_classes: int = 10,
mask_dim: Optional[int] = None,
inclusive: bool = True) -> None
Discrete Masked prior distribution.
The mask token can be indexed in several ways; the inclusive flag below controls whether the mask token is counted within num_classes.
Arguments:
num_classes
int - The number of classes in the distribution. Defaults to 10.
mask_dim
int - The index for the mask token. Defaults to num_classes - 1 if inclusive, or num_classes if exclusive.
inclusive
bool - Whether the mask is included in the specified number of classes.
def sample(shape: Tuple,
mask: Optional[Tensor] = None,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None) -> Tensor
Generates a specified number of samples.
Arguments:
shape
Tuple - The shape of the samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
mask
Optional[Tensor] - An optional mask to apply to the samples. Defaults to None.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
Tensor
- A tensor of samples.
def is_masked(sample: Tensor) -> Tensor
Creates a mask for whether a state is masked.
Arguments:
sample
Tensor - The sample to check.
Returns:
Tensor
- A float tensor indicating whether the sample is masked.
def pad_sample(sample: Tensor) -> Tensor
Pads the input sample with zeros along the last dimension.
Arguments:
sample
Tensor - The input sample to be padded.
Returns:
Tensor
- The padded sample.
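A short usage sketch grounded in the signatures above (shapes are illustrative):
import torch
prior = DiscreteMaskedPrior(num_classes=12)   # mask token index defaults to 11 when inclusive
xt = prior.sample(shape=(4, 20))              # for a masked prior this is typically an all-mask starting state
mask_flags = prior.is_masked(xt)              # 1.0 where a position holds the mask token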
class LinearHarmonicPrior(PriorDistribution)
A subclass representing a Linear Harmonic prior distribution from Jing et al. https://arxiv.org/abs/2304.02198.
def __init__(length: Optional[int] = None,
distance: Float = 3.8,
center: Bool = False,
rng_generator: Optional[torch.Generator] = None,
device: Union[str, torch.device] = "cpu") -> None
Linear Harmonic prior distribution.
Arguments:
length
Optional[int] - The number of points in a batch.
distance
Float - RMS distance between adjacent points in the line graph.
center
bool - Whether to center the samples around the mean. Defaults to False.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
device
Optional[str] - Device to place the samples on (default is "cpu").
def sample(shape: Tuple,
mask: Optional[Tensor] = None,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None) -> Tensor
Generates a specified number of samples from the Harmonic prior distribution.
Arguments:
shape
Tuple - The shape of the samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
mask
Optional[Tensor] - An optional mask to apply to the samples. Defaults to None.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
Tensor
- A tensor of samples.
class GaussianPrior(PriorDistribution)
A subclass representing a Gaussian prior distribution.
def __init__(mean: Float = 0.0,
std: Float = 1.0,
center: Bool = False,
rng_generator: Optional[torch.Generator] = None) -> None
Gaussian prior distribution.
Arguments:
mean
Float - The mean of the Gaussian distribution. Defaults to 0.0.
std
Float - The standard deviation of the Gaussian distribution. Defaults to 1.0.
center
bool - Whether to center the samples around the mean. Defaults to False.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
def sample(shape: Tuple,
mask: Optional[Tensor] = None,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None) -> Tensor
Generates a specified number of samples from the Gaussian prior distribution.
Arguments:
shape
Tuple - The shape of the samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
mask
Optional[Tensor] - An optional mask to apply to the samples. Defaults to None.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
Tensor
- A tensor of samples.
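A minimal usage sketch (seeding a generator for reproducibility is an assumption based on the rng_generator parameter):
import torch
prior = GaussianPrior(mean=0.0, std=1.0)
gen = torch.Generator().manual_seed(42)
x = prior.sample(shape=(4, 32, 3), rng_generator=gen)  # standard normal samples of shape (4, 32, 3)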
def remove_center_of_mass(data: Tensor,
mask: Optional[Tensor] = None) -> Tensor
Calculates the center of mass (CoM) of the given data.
Arguments:
data
- The input data with shape (..., nodes, features).
mask
- An optional binary mask to apply to the data with shape (..., nodes) to mask out interaction from CoM calculation. Defaults to None.
Returns:
The CoM of the data with shape (..., 1, features).
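The masked mean over the nodes dimension can be sketched in plain torch; this mirrors the description above, not necessarily the library's exact implementation:
import torch
data = torch.randn(2, 5, 3)                       # (..., nodes, features)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float)
num = (data * mask.unsqueeze(-1)).sum(dim=-2, keepdim=True)
com = num / mask.sum(dim=-1, keepdim=True).unsqueeze(-1)  # (..., 1, features)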
class TimeDistribution(ABC)
An abstract base class representing a time distribution.
Arguments:
discrete_time
Bool - Whether the time is discrete.
nsteps
Optional[int] - Number of steps for discretization.
min_t
Optional[Float] - Min continuous time.
max_t
Optional[Float] - Max continuous time.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
def __init__(discrete_time: Bool = False,
nsteps: Optional[int] = None,
min_t: Optional[Float] = None,
max_t: Optional[Float] = None,
rng_generator: Optional[torch.Generator] = None)
Initializes a TimeDistribution object.
@abstractmethod
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None) -> Float
Generates a specified number of samples from the time distribution.
Arguments:
n_samples
int - The number of samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
Float
- A tensor of samples.
class MixTimeDistribution()
An abstract base class representing a mixed time distribution.
Examples:
uniform_dist = UniformTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False)
beta_dist = BetaTimeDistribution(min_t=0.0, max_t=1.0, discrete_time=False, p1=2.0, p2=1.0)
mix_dist = MixTimeDistribution(uniform_dist, beta_dist, mix_fraction=0.5)
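Sampling then mixes the two distributions according to mix_fraction (a usage sketch based on the sample signature below):
t = mix_dist.sample(n_samples=128)  # each draw comes from uniform_dist with probability 0.5, otherwise beta_dist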
def __init__(dist1: TimeDistribution, dist2: TimeDistribution,
mix_fraction: Float)
Initializes a MixTimeDistribution object.
Arguments:
dist1
TimeDistribution - The first time distribution.
dist2
TimeDistribution - The second time distribution.
mix_fraction
Float - The fraction of samples to draw from dist1. Must be between 0 and 1.
def sample(n_samples: int,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None) -> Float
Generates a specified number of samples from the mixed time distribution.
Arguments:
n_samples
int - The number of samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
Float
- A tensor of samples.
class UniformTimeDistribution(TimeDistribution)
A class representing a uniform time distribution.
def __init__(min_t: Float = 0.0,
max_t: Float = 1.0,
discrete_time: Bool = False,
nsteps: Optional[int] = None,
rng_generator: Optional[torch.Generator] = None)
Initializes a UniformTimeDistribution object.
Arguments:
min_t
Float - The minimum time value.
max_t
Float - The maximum time value.
discrete_time
Bool - Whether the time is discrete.
nsteps
Optional[int] - Number of steps for discretization.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None)
Generates a specified number of samples from the uniform time distribution.
Arguments:
n_samples
int - The number of samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
A tensor of samples.
class SymmetricUniformTimeDistribution(TimeDistribution)
A class representing a symmetric uniform time distribution.
def __init__(min_t: Float = 0.0,
max_t: Float = 1.0,
discrete_time: Bool = False,
nsteps: Optional[int] = None,
rng_generator: Optional[torch.Generator] = None)
Initializes a SymmetricUniformTimeDistribution object.
Arguments:
min_t
Float - The minimum time value.
max_t
Float - The maximum time value.
discrete_time
Bool - Whether the time is discrete.
nsteps
Optional[int] - Number of steps for discretization.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None)
Generates a specified number of samples from the uniform time distribution.
Arguments:
n_samples
int - The number of samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
A tensor of samples.
class LogitNormalTimeDistribution(TimeDistribution)
A class representing a logit normal time distribution.
def __init__(p1: Float = 0.0,
p2: Float = 1.0,
min_t: Float = 0.0,
max_t: Float = 1.0,
discrete_time: Bool = False,
nsteps: Optional[int] = None,
rng_generator: Optional[torch.Generator] = None)
Initializes a LogitNormalTimeDistribution object.
Arguments:
p1
Float - The first shape parameter of the logit normal distribution, i.e. the mean.
p2
Float - The second shape parameter of the logit normal distribution, i.e. the std.
min_t
Float - The minimum time value.
max_t
Float - The maximum time value.
discrete_time
Bool - Whether the time is discrete.
nsteps
Optional[int] - Number of steps for discretization.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None)
Generates a specified number of samples from the logit normal time distribution.
Arguments:
n_samples
int - The number of samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
A tensor of samples.
class BetaTimeDistribution(TimeDistribution)
A class representing a beta time distribution.
def __init__(p1: Float = 2.0,
p2: Float = 1.0,
min_t: Float = 0.0,
max_t: Float = 1.0,
discrete_time: Bool = False,
nsteps: Optional[int] = None,
rng_generator: Optional[torch.Generator] = None)
Initializes a BetaTimeDistribution object.
Arguments:
p1
Float - The first shape parameter of the beta distribution.
p2
Float - The second shape parameter of the beta distribution.
min_t
Float - The minimum time value.
max_t
Float - The maximum time value.
discrete_time
Bool - Whether the time is discrete.
nsteps
Optional[int] - Number of steps for discretization.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
def sample(n_samples: Union[int, Tuple[int, ...], torch.Size],
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None)
Generates a specified number of samples from the beta time distribution.
Arguments:
n_samples
int - The number of samples to generate.
device
str - The device on which to generate the samples, e.g. "cpu" or "cuda:0".
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
Returns:
A tensor of samples.
def float_time_to_index(time: torch.Tensor,
num_time_steps: int) -> torch.Tensor
Convert a float time value to a time index.
Arguments:
time
torch.Tensor - A tensor of float time values in the range [0, 1].
num_time_steps
int - The number of discrete time steps.
Returns:
torch.Tensor
- A tensor of time indices corresponding to the input float time values.
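A plausible sketch of this conversion, assuming simple scaling with clamping (the library's exact rounding behavior may differ):
import torch
def float_time_to_index_sketch(time: torch.Tensor, num_time_steps: int) -> torch.Tensor:
    # Scale [0, 1] to [0, num_time_steps - 1] and clamp to a valid index range.
    return torch.clamp((time * num_time_steps).long(), max=num_time_steps - 1)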
def log(t, eps=1e-20)
Compute the natural logarithm of a tensor, clamping values to avoid numerical instability.
Arguments:
t
Tensor - The input tensor.
eps
float, optional - The minimum value to clamp the input tensor (default is 1e-20).
Returns:
Tensor
- The natural logarithm of the input tensor.
class ContinuousSNRTransform(ABC)
A base class for continuous SNR schedules.
def __init__(direction: TimeDirection)
Initialize the ContinuousSNRTransform.
Arguments:
direction
TimeDirection - required, this defines in which direction the scheduler was built.
def calculate_log_snr(t: Tensor,
device: Union[str, torch.device] = "cpu",
synchronize: Optional[TimeDirection] = None) -> Tensor
Public wrapper to generate the time schedule as a tensor.
Arguments:
t
Tensor - The input tensor representing the time steps, with values ranging from 0 to 1.
device
Optional[str] - The device to place the schedule on. Defaults to "cpu".
synchronize
Optional[TimeDirection] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
Returns:
Tensor
- A tensor representing the log signal-to-noise (SNR) ratio for the given time steps.
def log_snr_to_alphas_sigmas(log_snr: Tensor) -> Tuple[Tensor, Tensor]
Converts log signal-to-noise ratio (SNR) to alpha and sigma values.
Arguments:
log_snr
Tensor - The input log SNR tensor.
Returns:
tuple[Tensor, Tensor] - A tuple containing the square root of the alpha and sigma values.
def derivative(t: Tensor, func: Callable) -> Tensor
Compute the derivative of a function; supports batched single-variable inputs.
Arguments:
t
Tensor - time variable at which derivatives are taken.
func
Callable - function for derivative calculation.
Returns:
Tensor
- derivative that is detached from the computational graph.
def calculate_general_sde_terms(t)
Compute the general SDE terms for a given time step t.
Arguments:
t
Tensor - The input tensor representing the time step.
Returns:
tuple[Tensor, Tensor] - A tuple containing the drift term f_t and the diffusion term g_t_2.
Notes:
This method computes the drift and diffusion terms of the general SDE, which can be used to simulate the stochastic process.
The drift term represents the deterministic part of the process, while the diffusion term represents the stochastic part.
def calculate_beta(t)
Compute the drift coefficient for the OU process of the form $dx = -\frac{1}{2} \beta(t)\, x\, dt + \sqrt{\beta(t)}\, dw_t$, where
$\beta(t) = \frac{d}{dt} \log(\alpha^2) = \frac{2}{\alpha} \frac{d\alpha}{dt}$.
Arguments:
t
Union[float, Tensor] - t in [0, 1].
Returns:
Tensor
- beta(t).
def calculate_alpha_log_snr(log_snr: Tensor) -> Tensor
Compute alpha values based on the log SNR.
Arguments:
log_snr
Tensor - The input tensor representing the log signal-to-noise ratio.
Returns:
Tensor
- A tensor representing the alpha values for the given log SNR.
Notes:
This method computes alpha values as the square root of the sigmoid of the log SNR.
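In plain torch, this relationship can be sketched as follows; pairing it with sigma = sqrt(sigmoid(-log_snr)) follows the standard variance-preserving convention and is an assumption here. Note that sigmoid(x) + sigmoid(-x) = 1, so the pair satisfies alpha**2 + sigma**2 = 1:
import torch
log_snr = torch.tensor([4.0, 0.0, -4.0])
alpha = torch.sqrt(torch.sigmoid(log_snr))    # high SNR gives alpha near 1
sigma = torch.sqrt(torch.sigmoid(-log_snr))   # high SNR gives sigma near 0
assert torch.allclose(alpha**2 + sigma**2, torch.ones(3))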
def calculate_alpha_t(t: Tensor) -> Tensor
Compute alpha values based on the log SNR schedule.
Arguments:
t
Tensor - The input tensor representing the time steps.
Returns:
Tensor
- A tensor representing the alpha values for the given time steps.
Notes:
This method computes alpha values as the square root of the sigmoid of the log SNR.
class CosineSNRTransform(ContinuousSNRTransform)
A cosine SNR schedule.
Arguments:
nu
Optional[Float] - Hyperparameter for the cosine schedule exponent (default is 1.0).
s
Optional[Float] - Hyperparameter for the cosine schedule shift (default is 0.008).
def __init__(nu: Float = 1.0, s: Float = 0.008)
Initialize the CosineSNRTransform.
class LinearSNRTransform(ContinuousSNRTransform)
A Linear SNR schedule.
def __init__(min_value: Float = 1.0e-4)
Initialize the Linear SNR Transform.
Arguments:
min_value
Float - min value of SNR, defaults to 1.0e-4.
class LinearLogInterpolatedSNRTransform(ContinuousSNRTransform)
A Linear Log space interpolated SNR schedule.
def __init__(min_value: Float = -7.0, max_value=13.5)
Initialize the Linear log space interpolated SNR Schedule from Chroma.
Arguments:
min_value
Float - The min log SNR value.
max_value
Float - The max log SNR value.
class DiscreteNoiseSchedule(ABC)
A base class for discrete noise schedules.
def __init__(nsteps: int, direction: TimeDirection)
Initialize the DiscreteNoiseSchedule.
Arguments:
nsteps
int - number of discrete steps.
direction
TimeDirection - required, this defines in which direction the scheduler was built.
def generate_schedule(nsteps: Optional[int] = None,
device: Union[str, torch.device] = "cpu",
synchronize: Optional[TimeDirection] = None) -> Tensor
Generate the noise schedule as a tensor.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
device
Optional[str] - Device to place the schedule on (default is "cpu").
synchronize
Optional[TimeDirection] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
def calculate_derivative(
nsteps: Optional[int] = None,
device: Union[str, torch.device] = "cpu",
synchronize: Optional[TimeDirection] = None) -> Tensor
Calculate the time derivative of the schedule.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
device
Optional[str] - Device to place the schedule on (default is "cpu").
synchronize
Optional[TimeDirection] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
Returns:
Tensor
- A tensor representing the time derivative of the schedule.
Raises:
NotImplementedError
- If the derivative calculation is not implemented for this schedule.
class DiscreteCosineNoiseSchedule(DiscreteNoiseSchedule)
A cosine discrete noise schedule.
def __init__(nsteps: int, nu: Float = 1.0, s: Float = 0.008)
Initialize the CosineNoiseSchedule.
Arguments:
nsteps
int - Number of discrete steps.
nu
Optional[Float] - Hyperparameter for the cosine schedule exponent (default is 1.0).
s
Optional[Float] - Hyperparameter for the cosine schedule shift (default is 0.008).
class DiscreteLinearNoiseSchedule(DiscreteNoiseSchedule)
A linear discrete noise schedule.
def __init__(nsteps: int, beta_start: Float = 1e-4, beta_end: Float = 0.02)
Initialize the DiscreteLinearNoiseSchedule.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
beta_start
Optional[Float] - starting beta value. Defaults to 1e-4.
beta_end
Optional[Float] - end beta value. Defaults to 0.02.
class ContinuousExpNoiseTransform(ABC)
A base class for continuous schedules.
alpha = exp(- sigma) where 1 - alpha controls the masking fraction.
def __init__(direction: TimeDirection)
Initialize the ContinuousExpNoiseTransform.
Arguments:
direction
TimeDirection - required, this defines in which direction the scheduler was built.
def calculate_sigma(t: Tensor,
device: Union[str, torch.device] = "cpu",
synchronize: Optional[TimeDirection] = None) -> Tensor
Calculate the sigma for the given time steps.
Arguments:
t
Tensor - The input tensor representing the time steps, with values ranging from 0 to 1.
device
Optional[str] - The device to place the schedule on. Defaults to "cpu".
synchronize
Optional[TimeDirection] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
Returns:
Tensor
- A tensor representing the sigma values for the given time steps.
Raises:
ValueError
- If the input time steps exceed the maximum allowed value of 1.
def sigma_to_alpha(sigma: Tensor) -> Tensor
Converts sigma to alpha values by alpha = exp(- sigma).
Arguments:
sigma
Tensor - The input sigma tensor.
Returns:
Tensor
- A tensor containing the alpha values.
class CosineExpNoiseTransform(ContinuousExpNoiseTransform)
A cosine Exponential noise schedule.
def __init__(eps: Float = 1.0e-3)
Initialize the CosineExpNoiseTransform.
Arguments:
eps
Float - small number to prevent numerical issues.
def d_dt_sigma(t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor
Compute the derivative of sigma with respect to time.
Arguments:
t
Tensor - The input tensor representing the time steps.
device
Optional[str] - The device to place the schedule on. Defaults to "cpu".
Returns:
Tensor
- A tensor representing the derivative of sigma with respect to time.
Notes:
The derivative of sigma as a function of time is given by:
d/dt sigma(t) = d/dt (-log(cos(t * pi / 2) + eps))
Using the chain rule, we get:
d/dt sigma(t) = (-1 / (cos(t * pi / 2) + eps)) * (-sin(t * pi / 2) * pi / 2)
This is the derivative that is computed and returned by this method.
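A quick numerical check of this formula with autograd (the eps value is illustrative):
import torch
eps = 1.0e-3
t = torch.tensor([0.25, 0.5, 0.75], requires_grad=True)
sigma = -torch.log(torch.cos(t * torch.pi / 2) + eps)
sigma.sum().backward()
analytic = (torch.pi / 2) * torch.sin(t * torch.pi / 2) / (torch.cos(t * torch.pi / 2) + eps)
assert torch.allclose(t.grad, analytic.detach())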
class LogLinearExpNoiseTransform(ContinuousExpNoiseTransform)
A log linear exponential schedule.
def __init__(eps: Float = 1.0e-3)
Initialize the LogLinearExpNoiseTransform.
Arguments:
eps
Float - small value to prevent numerical issues.
def d_dt_sigma(t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor
Compute the derivative of sigma with respect to time.
Arguments:
t
Tensor - The input tensor representing the time steps.
device
Optional[str] - The device to place the schedule on. Defaults to "cpu".
Returns:
Tensor
- A tensor representing the derivative of sigma with respect to time.
class TimeDirection(Enum)
Enum for the direction of the noise schedule.
UNIFIED - Noise(0) --> Data(1)
DIFFUSION - Noise(1) --> Data(0)
class InferenceSchedule(ABC)
A base class for inference time schedules.
def __init__(nsteps: int,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
device: Union[str, torch.device] = "cpu")
Initialize the InferenceSchedule.
Arguments:
nsteps
int - Number of time steps.
min_t
Float - minimum time value, defaults to 0.
padding
Float - padding time value, defaults to 0.
dilation
Float - dilation time value, i.e. the number of replicates, defaults to 0.
direction
Optional[str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
device
Optional[str] - Device to place the schedule on (default is "cpu").
@abstractmethod
def generate_schedule(
nsteps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None) -> Tensor
Generate the time schedule as a tensor.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
device
Optional[str] - Device to place the schedule on (default is "cpu").
def pad_time(n_samples: int,
scalar_time: Float,
device: Optional[Union[str, torch.device]] = None) -> Tensor
Creates a tensor of shape (n_samples,) filled with a scalar time value.
Arguments:
n_samples
int - The desired dimension of the output tensor.
scalar_time
Float - The scalar time value to fill the tensor with.
Returns:
Tensor
- A tensor of shape (n_samples,) filled with the scalar time value.
class ContinuousInferenceSchedule(InferenceSchedule)
A base class for continuous time inference schedules.
def __init__(nsteps: int,
inclusive_end: bool = False,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
device: Union[str, torch.device] = "cpu")
Initialize the ContinuousInferenceSchedule.
Arguments:
nsteps
int - Number of time steps.
inclusive_end
bool - If True, include the end value (1.0) in the schedule, otherwise end at 1.0 - 1/nsteps (default is False).
min_t
Float - minimum time value, defaults to 0.
padding
Float - padding time value, defaults to 0.
dilation
Float - dilation time value, i.e. the number of replicates, defaults to 0.
direction
Optional[str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
device
Optional[str] - Device to place the schedule on (default is "cpu").
def discretize(nsteps: Optional[int] = None,
schedule: Optional[Tensor] = None,
device: Optional[Union[str, torch.device]] = None) -> Tensor
Discretize the time schedule into a list of time deltas.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
schedule
Optional[Tensor] - Time schedule; if None, it will be generated with generate_schedule.
device
Optional[str] - Device to place the schedule on (default is "cpu").
Returns:
Tensor
- A tensor of time deltas.
class DiscreteInferenceSchedule(InferenceSchedule)
A base class for discrete time inference schedules.
def discretize(nsteps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None) -> Tensor
Discretize the time schedule into a list of time deltas.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
device
Optional[str] - Device to place the schedule on (default is "cpu").
Returns:
Tensor
- A tensor of time deltas.
class DiscreteLinearInferenceSchedule(DiscreteInferenceSchedule)
A linear time schedule for discrete time inference.
def __init__(nsteps: int,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
device: Union[str, torch.device] = "cpu")
Initialize the DiscreteLinearInferenceSchedule.
Arguments:
nsteps
int - Number of time steps.
min_t
Float - minimum time value, defaults to 0.
padding
Float - padding time value, defaults to 0.
dilation
Float - dilation time value, i.e. the number of replicates, defaults to 0.
direction
Optional[str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
device
Optional[str] - Device to place the schedule on (default is "cpu").
def generate_schedule(
nsteps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None) -> Tensor
Generate the linear time schedule as a tensor.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
device
Optional[str] - Device to place the schedule on (default is "cpu").
Returns:
Tensor
- A tensor of time steps.
class LinearInferenceSchedule(ContinuousInferenceSchedule)
A linear time schedule for continuous time inference.
def __init__(nsteps: int,
inclusive_end: bool = False,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
device: Union[str, torch.device] = "cpu")
Initialize the LinearInferenceSchedule.
Arguments:
nsteps
int - Number of time steps.
inclusive_end
bool - If True, include the end value (1.0) in the schedule, otherwise end at 1.0 - 1/nsteps (default is False).
min_t
Float - minimum time value, defaults to 0.
padding
Float - padding time value, defaults to 0.
dilation
Float - dilation time value, i.e. the number of replicates, defaults to 0.
direction
Optional[str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
device
Optional[str] - Device to place the schedule on (default is "cpu").
def generate_schedule(
nsteps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None) -> Tensor
Generate the linear time schedule as a tensor.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
device
Optional[str] - Device to place the schedule on (default is "cpu").
Returns:
Tensor
- A tensor of time steps.
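With the defaults described above (min_t=0, inclusive_end=False), a 4-step linear schedule should look like the following sketch:
schedule = LinearInferenceSchedule(nsteps=4)
ts = schedule.generate_schedule()   # expected: tensor([0.00, 0.25, 0.50, 0.75])
dts = schedule.discretize()         # expected: tensor([0.25, 0.25, 0.25, 0.25])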
class PowerInferenceSchedule(ContinuousInferenceSchedule)
A power time schedule for inference, where time steps are generated by raising a uniform schedule to a specified power.
def __init__(nsteps: int,
inclusive_end: bool = False,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
exponent: Float = 1.0,
direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
device: Union[str, torch.device] = "cpu")
Initialize the PowerInferenceSchedule.
Arguments:
nsteps
int - Number of time steps.
inclusive_end
bool - If True, include the end value (1.0) in the schedule, otherwise end at <1.0 (default is False).
min_t
Float - minimum time value, defaults to 0.
padding
Float - padding time value, defaults to 0.
dilation
Float - dilation time value, i.e. the number of replicates, defaults to 0.
exponent
Float - Power parameter, defaults to 1.0.
direction
Optional[str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
device
Optional[str] - Device to place the schedule on (default is "cpu").
def generate_schedule(
nsteps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None) -> Tensor
Generate the power time schedule as a tensor.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
device
Optional[str] - Device to place the schedule on (default is "cpu").
Returns:
Tensor
- A tensor of time steps.
class LogInferenceSchedule(ContinuousInferenceSchedule)
A log time schedule for inference, where time steps are generated by taking the logarithm of a uniform schedule.
def __init__(nsteps: int,
inclusive_end: bool = False,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
exponent: Float = -2.0,
direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
device: Union[str, torch.device] = "cpu")
Initialize the LogInferenceSchedule.
Returns a log-space time schedule, which for 100 steps with default parameters is:
tensor([0.0000, 0.0455, 0.0889, 0.1303, 0.1699, 0.2077, 0.2439, 0.2783, 0.3113,
0.3427, 0.3728, 0.4015, 0.4288, 0.4550, 0.4800, 0.5039, 0.5266, 0.5484,
0.5692, 0.5890, 0.6080, 0.6261, 0.6434, 0.6599, 0.6756, 0.6907, 0.7051,
0.7188, 0.7319, 0.7444, 0.7564, 0.7678, 0.7787, 0.7891, 0.7991, 0.8086,
0.8176, 0.8263, 0.8346, 0.8425, 0.8500, 0.8572, 0.8641, 0.8707, 0.8769,
0.8829, 0.8887, 0.8941, 0.8993, 0.9043, 0.9091, 0.9136, 0.9180, 0.9221,
0.9261, 0.9299, 0.9335, 0.9369, 0.9402, 0.9434, 0.9464, 0.9492, 0.9520,
0.9546, 0.9571, 0.9595, 0.9618, 0.9639, 0.9660, 0.9680, 0.9699, 0.9717,
0.9734, 0.9751, 0.9767, 0.9782, 0.9796, 0.9810, 0.9823, 0.9835, 0.9847,
0.9859, 0.9870, 0.9880, 0.9890, 0.9899, 0.9909, 0.9917, 0.9925, 0.9933,
0.9941, 0.9948, 0.9955, 0.9962, 0.9968, 0.9974, 0.9980, 0.9985, 0.9990,
0.9995])
Arguments:
nsteps
int - Number of time steps.
inclusive_end
bool - If True, include the end value (1.0) in the schedule, otherwise end at <1.0 (default is False).
min_t
Float - minimum time value, defaults to 0.
padding
Float - padding time value, defaults to 0.
dilation
Float - dilation time value, i.e. the number of replicates, defaults to 0.
exponent
Float - log space exponent parameter, defaults to -2.0. The lower the number, the more aggressive the acceleration from 0 to 0.9 will be, leaving more steps between 0.9 and 1.0.
direction
Optional[str] - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).
device
Optional[str] - Device to place the schedule on (default is "cpu").
def generate_schedule(
nsteps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None) -> Tensor
Generate the log time schedule as a tensor.
Arguments:
nsteps
Optional[int] - Number of time steps. If None, uses the value from initialization.
device
Optional[str] - Device to place the schedule on (default is "cpu").
class MDLM(Interpolant)
A Masked discrete Diffusion Language Model (MDLM) interpolant.
Examples:
>>> import torch
>>> from bionemo.bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior
>>> from bionemo.bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.bionemo.moco.interpolants.continuous_time.discrete.mdlm import MDLM
>>> from bionemo.bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform
>>> from bionemo.bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
mdlm = MDLM(
time_distribution = UniformTimeDistribution(discrete_time = False,...),
prior_distribution = DiscreteMaskedPrior(...),
noise_schedule = CosineExpNoiseTransform(...),
)
model = Model(...)
# Training
for epoch in range(1000):
data = data_loader.get(...)
time = mdlm.sample_time(batch_size)
xt = mdlm.interpolate(data, time)
logits = model(xt, time)
loss = mdlm.loss(logits, data, xt, time)
loss.backward()
# Generation
x_pred = mdlm.sample_prior(data.shape)
schedule = LinearInferenceSchedule(...)
inference_time = schedule.generate_schedule()
dts = schedule.discretize()
for t, dt in zip(inference_time, dts):
time = torch.full((batch_size,), t)
logits = model(x_pred, time)
x_pred = mdlm.step(logits, time, x_pred, dt)
return x_pred
def __init__(time_distribution: TimeDistribution,
prior_distribution: DiscreteMaskedPrior,
noise_schedule: ContinuousExpNoiseTransform,
device: str = "cpu",
rng_generator: Optional[torch.Generator] = None)
Initialize the Masked Discrete Language Model (MDLM) interpolant.
Arguments:
time_distribution
TimeDistribution - The distribution governing the time variable in the diffusion process.
prior_distribution
DiscreteMaskedPrior - The prior distribution over the discrete token space, including masked tokens.
noise_schedule
ContinuousExpNoiseTransform - The noise schedule defining the noise intensity as a function of time.
device
str, optional - The device to use for computations. Defaults to "cpu".
rng_generator
Optional[torch.Generator], optional - The random number generator for reproducibility. Defaults to None.
def interpolate(data: Tensor, t: Tensor)
Get x(t) with given time t from noise and data.
Arguments:
data
Tensor - target discrete ids.
t
Tensor - time.
def forward_process(data: Tensor, t: Tensor) -> Tensor
Apply the forward process to the data at time t.
Arguments:
data
Tensor - target discrete ids.
t
Tensor - time.
Returns:
Tensor
- x(t) after applying the forward process.
def loss(logits: Tensor,
target: Tensor,
xt: Tensor,
time: Tensor,
mask: Optional[Tensor] = None,
use_weight=True)
Calculate the cross-entropy loss between the model prediction and the target output.
The loss is calculated between the batch x node x class logits and the target batch x node,
considering the current state of the discrete sequence xt at time time.
If use_weight is True, the loss is weighted by the reduced form of the MDLM time weight for continuous NELBO,
as specified in equation 11 of https://arxiv.org/pdf/2406.07524. This weight is proportional to the derivative
of the noise schedule with respect to time, and is used to emphasize the importance of accurate predictions at
certain times in the diffusion process.
Arguments:
logits
Tensor - The predicted output from the model, with shape batch x node x class.
target
Tensor - The target output for the model prediction, with shape batch x node.
xt
Tensor - The current state of the discrete sequence, with shape batch x node.
time
Tensor - The time at which the loss is calculated.
mask
Optional[Tensor], optional - The mask for the data point. Defaults to None.
use_weight
bool, optional - Whether to use the MDLM time weight for the loss. Defaults to True.
Returns:
Tensor
- The calculated loss batch tensor.
def step(logits: Tensor,
t: Tensor,
xt: Tensor,
dt: Tensor,
temperature: float = 1.0) -> Tensor
Perform a single MDLM DDPM step.
Arguments:
logits
Tensor - The input logits.
t
Tensor - The current time step.
xt
Tensor - The current state.
dt
Tensor - The time step increment.
temperature
float - Softmax temperature, defaults to 1.0.
Returns:
Tensor
- The updated state.
def get_num_steps_confidence(xt: Tensor, num_tokens_unmask: int = 1)
Calculate the maximum number of steps with confidence.
This method computes the maximum count of occurrences where the input tensor xt matches the mask_index
along the last dimension (-1). The result is returned as a single float value.
Arguments:
xt
Tensor - Input tensor to evaluate against the mask index.
num_tokens_unmask
int - number of tokens to unmask at each step.
Returns:
float
- The maximum number of steps with confidence (i.e., matching the mask index).
def step_confidence(logits: Tensor,
xt: Tensor,
curr_step: int,
num_steps: int,
logit_temperature: float = 1.0,
randomness: float = 1.0,
confidence_temperature: float = 1.0,
num_tokens_unmask: int = 1) -> Tensor
Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.
Method taken from GenMol Lee et al. https://arxiv.org/abs/2501.06158
Arguments:
logits
- Predicted logits.
xt
- Input sequence.
curr_step
- Current step.
num_steps
- Total number of steps.
logit_temperature
- Temperature for softmax over logits.
randomness
- Scale for Gumbel noise.
confidence_temperature
- Temperature for Gumbel confidence.
num_tokens_unmask
- number of tokens to unmask each step.
Returns:
Updated input sequence xt, unmasking num_tokens_unmask tokens each step.
def step_argmax(model_out: Tensor)
Returns the index of the maximum value in the last dimension of the model output.
Arguments:
model_out
Tensor - The output of the model.
Returns:
Tensor
- The index of the maximum value in the last dimension of the model output.
def calculate_score(logits, x, t)
Returns score of the given sample x at time t with the corresponding model output logits.
Arguments:
logits
Tensor - The output of the model.
x
Tensor - The current data point.
t
Tensor - The current time.
Returns:
Tensor
- The score defined in Appendix C.3, Equation 76 of MDLM.
def step_self_path_planning(logits: Tensor,
xt: Tensor,
t: Tensor,
curr_step: int,
num_steps: int,
logit_temperature: float = 1.0,
randomness: float = 1.0,
confidence_temperature: float = 1.0,
score_type: Literal["confidence",
"random"] = "confidence",
fix_mask: Optional[Tensor] = None) -> Tensor
Self Path Planning (P2) Sampling from Peng et al. https://arxiv.org/html/2502.03540v1.
Arguments:
logits
Tensor - Predicted logits for sampling.
xt
Tensor - Input sequence to be updated.
t
Tensor - Time tensor (e.g., time steps or temporal info).
curr_step
int - Current iteration in the planning process.
num_steps
int - Total number of planning steps.
logit_temperature
float - Temperature for logits (default: 1.0).
randomness
float - Introduced randomness level (default: 1.0).
confidence_temperature
float - Temperature for confidence scoring (default: 1.0).
score_type
Literal["confidence", "random"] - Sampling score type (default: "confidence").
fix_mask
Optional[Tensor] - initial mask that is True where a position is not a mask token (default: None).
Returns:
Tensor
- Updated input sequence xt after iterative unmasking.
def topk_lowest_masking(scores: Tensor, cutoff_len: Tensor)
Generates a mask for the lowest scoring elements up to a specified cutoff length.
Arguments:
scores
Tensor - Input scores tensor with shape (..., num_elements).
cutoff_len
Tensor - Number of lowest-scoring elements to mask (per batch element).
Returns:
Tensor
- Boolean mask tensor with the same shape as scores, where True indicates the cutoff_len lowest scores.
Example:
scores = torch.tensor([[0.9, 0.8, 0.1, 0.05], [0.7, 0.4, 0.3, 0.2]])
cutoff_len = 2
mask = topk_lowest_masking(scores, cutoff_len)
print(mask)
tensor([[False, False, True, True],
[False, True, True, False]])
def stochastic_sample_from_categorical(logits: Tensor,
temperature: float = 1.0,
noise_scale: float = 1.0)
Stochastically samples from a categorical distribution defined by input logits, with optional temperature and noise scaling for diverse sampling.
Arguments:
logits
Tensor - Input logits tensor with shape (..., num_categories).
temperature
float, optional - Softmax temperature. Higher values produce more uniform samples. Defaults to 1.0.
noise_scale
float, optional - Scale for Gumbel noise. Higher values produce more diverse samples. Defaults to 1.0.
Returns:
tuple:
- tokens (LongTensor): Sampling result (category indices) with shape (... , )
- scores (Tensor): Corresponding log-softmax scores for the sampled tokens, with shape (... , )
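One standard way to implement this kind of Gumbel-noise sampling, consistent with the description above (a sketch, not necessarily the library's exact code):
import torch
logits = torch.randn(2, 5, 8)                                  # (batch, nodes, num_categories)
gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
noisy = logits / 1.0 + 1.0 * gumbel                            # temperature=1.0, noise_scale=1.0
tokens = noisy.argmax(dim=-1)                                  # (batch, nodes)
scores = torch.log_softmax(logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)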
class DiscreteFlowMatcher(Interpolant)
A Discrete Flow Model (DFM) interpolant.
def __init__(time_distribution: TimeDistribution,
prior_distribution: DiscretePriorDistribution,
device: str = "cpu",
eps: Float = 1e-5,
rng_generator: Optional[torch.Generator] = None)
Initialize the DFM interpolant.
Arguments:
time_distribution
TimeDistribution - The time distribution for the diffusion process.
prior_distribution
DiscretePriorDistribution - The prior distribution for the discrete masked tokens.
device
str, optional - The device to use for computations. Defaults to "cpu".
eps
- small Float to prevent dividing by zero.
rng_generator
Optional[torch.Generator] - An optional torch.Generator for reproducible sampling. Defaults to None.
def interpolate(data: Tensor, t: Tensor, noise: Tensor)
Get x(t) with given time t from noise and data.
Arguments:
data
Tensor - target discrete ids.
t
Tensor - time.
noise
- tensor of noise ids.
def loss(logits: Tensor,
target: Tensor,
time: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
use_weight: Bool = False)
Calculate the cross-entropy loss between the model prediction and the target output.
The loss is calculated between the batch x node x class logits and the target batch x node.
If using a masked prior, please pass in the correct mask to calculate loss values only on masked states,
i.e. mask = data_mask * is_masked_state, which is calculated with self.prior_dist.is_masked(xt).
If use_weight is True, the loss is weighted by 1/(1-t), as defined in equation 24 of Appendix C in https://arxiv.org/pdf/2402.04997.
Arguments:
logits
Tensor - The predicted output from the model, with shape batch x node x class.
target
Tensor - The target output for the model prediction, with shape batch x node.
time
Tensor - The time at which the loss is calculated.
mask
Optional[Tensor], optional - The mask for the data point. Defaults to None.
use_weight
bool, optional - Whether to use the DFM time weight for the loss. Defaults to False.
Returns:
Tensor
- The calculated loss batch tensor.
def step(logits: Tensor,
t: Tensor,
xt: Tensor,
dt: Tensor | float,
temperature: Float = 1.0,
stochasticity: Float = 1.0) -> Tensor
Perform a single step of DFM euler updates.
Arguments:
logits
Tensor - The input logits.
t
Tensor - The current time step.
xt
Tensor - The current state.
dt
Tensor | float - The time step increment.
temperature
Float, optional - The temperature for the softmax calculation. Defaults to 1.0.
stochasticity
Float, optional - The stochasticity value for the step calculation. Defaults to 1.0.
Returns:
Tensor
- The updated state.
def step_purity(logits: Tensor,
t: Tensor,
xt: Tensor,
dt: Tensor | float,
temperature: Float = 1.0,
stochasticity: Float = 1.0) -> Tensor
Perform a single step of purity sampling.
https://github.com/jasonkyuyim/multiflow/blob/6278899970523bad29953047e7a42b32a41dc813/multiflow/data/interpolant.py#L346
TODO: check whether -1e9 and 1e-9 are small enough, or if using torch.inf would be better.
Arguments:
logits
Tensor - The input logits.
t
Tensor - The current time step.
xt
Tensor - The current state.
dt
Tensor - The time step increment.
temperature
Float, optional - The temperature for the softmax calculation. Defaults to 1.0.
stochasticity
Float, optional - The stochasticity value for the step calculation. Defaults to 1.0.
Returns:
Tensor
- The updated state.
def step_argmax(model_out: Tensor)
Returns the index of the maximum value in the last dimension of the model output.
Arguments:
model_out
Tensor - The output of the model.
def step_simple_sample(model_out: Tensor,
temperature: float = 1.0,
num_samples: int = 1)
Samples from the model output logits. Leads to more diversity than step_argmax.
Arguments:
model_out
Tensor - The output of the model.
temperature
Float, optional - The temperature for the softmax calculation. Defaults to 1.0.
num_samples
int - Number of samples to return.
class OTSampler()
Sampler for Exact Mini-batch Optimal Transport Plan.
OTSampler implements sampling coordinates according to an OT plan (wrt squared Euclidean cost)
with different implementations of the plan calculation. Code is adapted from https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py
def __init__(method: str = "exact",
device: Union[str, torch.device] = "cpu",
num_threads: int = 1) -> None
Initialize the OTSampler class.
Arguments:
method
str - Choose which optimal transport solver you would like to use. Currently only exact OT solvers (pot.emd) are supported.
device
Union[str, torch.device], optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
num_threads
Union[int, str], optional - Number of threads to use for the OT solver. If "max", uses the maximum number of threads. Default is 1.
Raises:
ValueError
- If the OT solver is not documented.
NotImplementedError
- If the OT solver is not implemented.
def to_device(device: str)
Moves all internal tensors to the specified device and updates the self.device attribute.
Arguments:
device
str - The device to move the tensors to (e.g. "cpu", "cuda:0").
Notes:
This method is used to transfer the internal state of the OTSampler to a different device.
It updates the self.device attribute to reflect the new device and moves all internal tensors to the specified device.
def sample_map(pi: Tensor,
batch_size: int,
replace: Bool = False) -> Tuple[Tensor, Tensor]
Draw source and target samples from pi $(x,z) \sim \pi$.
Arguments:
pi
Tensor - shape (bs, bs), the OT matrix between noise and data in the minibatch.
batch_size
int - The batch size of the minibatch.
replace
bool - sampling with or without replacement from the OT plan, defaults to False.
Returns:
Tuple
- tuple of 2 tensors representing the indices of noise and data samples from pi.
def get_ot_matrix(x0: Tensor,
x1: Tensor,
mask: Optional[Tensor] = None) -> Tensor
Compute the OT matrix between a source and a target minibatch.
Arguments:
x0
Tensor - shape (bs, *dim), noise from source minibatch.
x1
Tensor - shape (bs, *dim), data from target minibatch.
mask
Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes); if not provided, no mask is applied. Defaults to None.
Returns:
p
Tensor - shape (bs, bs), the OT matrix between noise and data in the minibatch.
def apply_augmentation(
x0: Tensor,
x1: Tensor,
mask: Optional[Tensor] = None,
replace: Bool = False,
sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0"
) -> Tuple[Tensor, Tensor, Optional[Tensor]]
Sample indices for noise and data in minibatch according to OT plan.
Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
minibatch and draw source and target samples from pi $(x,z) \sim \pi$.
Arguments:
x0
Tensor - shape (bs, *dim), noise from source minibatch.
x1
Tensor - shape (bs, *dim), data from target minibatch.
mask
Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes); if not provided, no mask is applied. Defaults to None.
replace
bool - sampling with or without replacement from the OT plan, defaults to False.
sort
str - Optional Literal string to sort either x1 or x0 based on the input.
Returns:
Tuple
- tuple of 2 tensors, or 3 tensors if mask is used, representing the noise (plus mask) and data samples following OT plan pi.
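A usage sketch based on the documented signature (shapes are illustrative; per the return annotation, the third element is the optional mask):
import torch
sampler = OTSampler(method="exact", device="cpu")
x0 = torch.randn(32, 16, 3)   # noise minibatch
x1 = torch.randn(32, 16, 3)   # data minibatch
x0, x1, _ = sampler.apply_augmentation(x0, x1, mask=None, sort="x0")
# x0[i] and x1[i] are now paired according to the OT plan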
class EquivariantOTSampler()
Sampler for Mini-batch Optimal Transport Plan with cost calculated after Kabsch alignment.
EquivariantOTSampler implements sampling coordinates according to an OT plan
(wrt squared Euclidean cost after Kabsch alignment) with different implementations of the plan calculation.
def __init__(method: str = "exact",
device: Union[str, torch.device] = "cpu",
num_threads: int = 1) -> None
Initialize the EquivariantOTSampler class.
Arguments:
method
str - Choose which optimal transport solver you would like to use. Currently only exact OT solvers (pot.emd) are supported.
device
Union[str, torch.device], optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
num_threads
Union[int, str], optional - Number of threads to use for the OT solver. If "max", uses the maximum number of threads. Default is 1.
Raises:
ValueError
- If the OT solver is not documented.
NotImplementedError
- If the OT solver is not implemented.
def to_device(device: str)
Moves all internal tensors to the specified device and updates the self.device attribute.
Arguments:
device
str - The device to move the tensors to (e.g. "cpu", "cuda:0").
Notes:
This method is used to transfer the internal state of the EquivariantOTSampler to a different device.
It updates the self.device attribute to reflect the new device and moves all internal tensors to the specified device.
def sample_map(pi: Tensor,
batch_size: int,
replace: Bool = False) -> Tuple[Tensor, Tensor]
Draw source and target samples from pi $(x,z) \sim \pi$.
Arguments:
pi
Tensor - shape (bs, bs), the OT matrix between noise and data in the minibatch.
batch_size
int - The batch size of the minibatch.
replace
bool - sampling with or without replacement from the OT plan, defaults to False.
Returns:
Tuple
- tuple of 2 tensors representing the indices of noise and data samples from pi.
def kabsch_align(target: Tensor, noise: Tensor) -> Tensor
Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.
Arguments:
target
Tensor - shape (N, *dim), data from source minibatch.
noise
Tensor - shape (N, *dim), noise from source minibatch.
Returns:
R
Tensor - shape (dim, dim), the rotation matrix.
def get_ot_matrix(x0: Tensor,
x1: Tensor,
mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]
Compute the OT matrix between a source and a target minibatch.
Arguments:
x0
Tensor - shape (bs, *dim), noise from source minibatch.
x1
Tensor - shape (bs, *dim), data from target minibatch.
mask
Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes); if not provided, no mask is applied. Defaults to None.
Returns:
p
Tensor - shape (bs, bs), the OT matrix between noise and data in the minibatch.
Rs
Tensor - shape (bs, bs, dim, dim), the rotation matrices between noise and data in the minibatch.
def apply_augmentation(
x0: Tensor,
x1: Tensor,
mask: Optional[Tensor] = None,
replace: Bool = False,
sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0"
) -> Tuple[Tensor, Tensor, Optional[Tensor]]
Sample indices for noise and data in minibatch according to OT plan.
Compute the OT plan $\pi$ (wrt squared Euclidean cost after Kabsch alignment) between a source and a target
minibatch and draw source and target samples from pi $(x,z) \sim \pi$.
Arguments:
x0
Tensor - shape (bs, *dim), noise from source minibatch.
x1
Tensor - shape (bs, *dim), data from target minibatch.
mask
Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes); if not provided, no mask is applied. Defaults to None.
replace
bool - sampling with or without replacement from the OT plan, defaults to False.
sort
str - Optional Literal string to sort either x1 or x0 based on the input.
Returns:
Tuple
- tuple of 2 tensors representing the noise and data samples following OT plan pi.
class KabschAugmentation()
Point-wise Kabsch alignment.
def __init__()
Initialize the KabschAugmentation instance.
Notes:
Set up any internal variables (e.g. self.variable_name) as needed.
def kabsch_align(target: Tensor, noise: Tensor)
Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.
Arguments:
target
Tensor - shape (N, *dim), data from source minibatch.
noise
Tensor - shape (N, *dim), noise from source minibatch.
Returns:
R
Tensor - shape (dim, dim), the rotation matrix.
def batch_kabsch_align(target: Tensor, noise: Tensor)
Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.
Arguments:
target
Tensor - shape (B, N, *dim), data from source minibatch.
noise
Tensor - shape (B, N, *dim), noise from source minibatch.
Returns:
R
Tensor - shape (dim, dim), the rotation matrix.
def apply_augmentation(x0: Tensor,
x1: Tensor,
mask: Optional[Tensor] = None,
align_noise_to_data=True) -> Tuple[Tensor, Tensor]
Sample indices for noise and data in minibatch according to OT plan.
Compute the OT plan $\pi$ (wrt squared Euclidean cost after Kabsch alignment) between a source and a target
minibatch and draw source and target samples from pi $(x,z) \sim \pi$.
Arguments:
x0
Tensor - shape (bs, *dim), noise from source minibatch.
x1
Tensor - shape (bs, *dim), data from target minibatch.
mask
Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes); if not provided, no mask is applied. Defaults to None.
replace
bool - sampling with or without replacement from the OT plan, defaults to False.
align_noise_to_data
bool - Direction of alignment; default is True, meaning the noise is augmented to reduce the error to the data.
Returns:
Tuple
- tuple of 2 tensors representing the noise and data samples following OT plan pi.
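For reference, a textbook Kabsch rotation via SVD, assuming centered point clouds (a sketch, not necessarily the library's exact implementation):
import torch
def kabsch_rotation(target: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    # Cross-covariance of the centered clouds, shape (dim, dim).
    H = target.T @ noise
    U, S, Vt = torch.linalg.svd(H)
    d = torch.sign(torch.linalg.det(Vt.T @ U.T))
    D = torch.eye(H.shape[0])
    D[-1, -1] = d  # correct improper rotations (reflections)
    return Vt.T @ D @ U.T  # R such that target @ R.T approximately matches noise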
class AugmentationType(Enum)
An enumeration representing the type of Optimal Transport that can be used in Continuous Flow Matching.
These augmentation types can be used to train neural networks for specific tasks, such as denoising, image synthesis, or time-series forecasting.
class VDM(Interpolant)
A Variational Diffusion Models (VDM) interpolant.
Examples:
>>> import torch
>>> from bionemo.bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.bionemo.moco.interpolants.discrete_time.continuous.vdm import VDM
>>> from bionemo.bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform
>>> from bionemo.bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
vdm = VDM(
time_distribution = UniformTimeDistribution(...),
prior_distribution = GaussianPrior(...),
noise_schedule = CosineSNRTransform(...),
)
model = Model(...)
# Training
for epoch in range(1000):
data = data_loader.get(...)
time = vdm.sample_time(batch_size)
noise = vdm.sample_prior(data.shape)
xt = vdm.interpolate(data, noise, time)
x_pred = model(xt, time)
loss = vdm.loss(x_pred, data, time)
loss.backward()
# Generation
x_pred = vdm.sample_prior(data.shape)
for t in LinearInferenceSchedule(...).generate_schedule():
time = torch.full((batch_size,), t)
x_hat = model(x_pred, time)
x_pred = vdm.step(x_hat, time, x_pred)
return x_pred
def __init__(time_distribution: TimeDistribution,
prior_distribution: PriorDistribution,
noise_schedule: ContinuousSNRTransform,
prediction_type: Union[PredictionType, str] = PredictionType.DATA,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None)
Initializes the VDM interpolant.
Arguments:
time_distribution
TimeDistribution - The distribution of time steps, used to sample time points for the diffusion process.prior_distribution
PriorDistribution - The prior distribution of the variable, used as the starting point for the diffusion process.noise_schedule
ContinuousSNRTransform - The schedule of noise, defining the amount of noise added at each time step.prediction_type
PredictionType, optional - The type of prediction, either "data" or another type. Defaults to "data".device
str, optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".rng_generator
- An optional :class:torch.Generator
for reproducible sampling. Defaults to None.def interpolate(data: Tensor, t: Tensor, noise: Tensor)
Get x(t) with given time t from noise and data.
Arguments:
data
Tensor - targett
Tensor - timenoise
Tensor - noise from prior()def forward_process(data: Tensor, t: Tensor, noise: Optional[Tensor] = None)
Get x(t) with given time t from noise and data.
Arguments:
data
Tensor - targett
Tensor - timenoise
Tensor, optional - noise from prior(). Defaults to Nonedef process_data_prediction(model_output: Tensor, sample, t)
Converts the model output to a data prediction based on the prediction type.
This conversion stems from the Progressive Distillation for Fast Sampling of Diffusion Models https://arxiv.org/pdf/2202.00512.
Given the model output and the sample, we convert the output to a data prediction based on the prediction type.
The conversion formulas are as follows:
- For "noise" prediction type: pred_data = (sample - noise_scale * model_output) / data_scale
- For "data" prediction type: pred_data = model_output
- For "v_prediction" prediction type: pred_data = data_scale * sample - noise_scale * model_output
Arguments:
model_output
Tensor - The output of the model.sample
Tensor - The input sample.t
Tensor - The time step.Returns:
The data prediction based on the prediction type.
Raises:
ValueError
- If the prediction type is not one of "noise", "data", or "v_prediction".def process_noise_prediction(model_output: Tensor, sample: Tensor, t: Tensor)
Does the same as process_data_prediction but converts the model output to noise.
Arguments:
model_output
Tensor - The output of the model.sample
Tensor - The input sample.t
Tensor - The time step.Returns:
The input as noise if the prediction type is "noise".
Raises:
ValueError
- If the prediction type is not "noise".def step(model_out: Tensor,
t: Tensor,
xt: Tensor,
dt: Tensor,
mask: Optional[Tensor] = None,
center: Bool = False,
temperature: Float = 1.0)
Do one step integration.
Arguments:
model_out
Tensor - The output of the model.xt
Tensor - The current data point.t
Tensor - The current time step.dt
Tensor - The time step increment.mask
Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.center
bool - Whether to center the data. Defaults to False.temperature
Float - The temperature parameter for low temperature sampling. Defaults to 1.0.Notes:
The temperature parameter controls the trade-off between diversity and sample quality.
Decreasing the temperature sharpens the sampling distribution to focus on more likely samples.
The impact of low temperature sampling should be ablated for each application.
def score(x_hat: Tensor, xt: Tensor, t: Tensor)
Converts the data prediction to the estimated score function.
Arguments:
x_hat
tensor - The predicted data point.xt
Tensor - The current data point.t
Tensor - The time step.Returns:
The estimated score function.
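For a Gaussian interpolant x_t = alpha_t * x + sigma_t * eps, the conversion follows the standard identity (stated as an assumption about the exact parameterization used here): s(x_t, t) = (alpha_t * x_hat - x_t) / sigma_t**2.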
def step_ddim(model_out: Tensor,
t: Tensor,
xt: Tensor,
dt: Tensor,
mask: Optional[Tensor] = None,
eta: Float = 0.0,
center: Bool = False)
Do one step of DDIM sampling.
From the DDPM equations, alpha_bar = alpha**2 and 1 - alpha**2 = sigma**2.
Arguments:
model_out
Tensor - output of the modelt
Tensor - current time stepxt
Tensor - current data pointdt
Tensor - The time step increment.mask
Optional[Tensor], optional - mask for the data point. Defaults to None.eta
Float, optional - DDIM sampling parameter. Defaults to 0.0.center
Bool, optional - whether to center the data point. Defaults to False.def set_loss_weight_fn(fn: Callable)
Sets the loss_weight attribute of the instance to the given function.
Arguments:
fn
- The function to set as the loss_weight attribute. This function should take three arguments: raw_loss, t, and weight_type.def loss_weight(raw_loss: Tensor,
t: Tensor,
weight_type: str,
dt: Float = 0.001) -> Tensor
Calculates the weight for the loss based on the given weight type.
This function computes the loss weight according to the specified weight_type.
The available weight types are:
- "ones": uniform weight of 1.0
- "data_to_noise": derived from Equation (9) of https://arxiv.org/pdf/2202.00512
- "variational_objective_discrete": based on the variational objective, see https://arxiv.org/pdf/2202.00512
Arguments:
raw_loss
Tensor - The raw loss calculated from the model prediction and target.t
Tensor - The time step.weight_type
str - The type of weight to use. Can be "ones", "data_to_noise", or "variational_objective_discrete".dt
Float, optional - The time step increment. Defaults to 0.001.Returns:
Tensor
- The weight for the loss.Raises:
ValueError
- If the weight type is not recognized.def loss(model_pred: Tensor,
target: Tensor,
t: Tensor,
dt: Optional[Float] = 0.001,
mask: Optional[Tensor] = None,
weight_type: str = "ones")
Calculates the loss given the model prediction, target, and time.
Arguments:
model_pred
Tensor - The predicted output from the model.target
Tensor - The target output for the model prediction.t
Tensor - The time at which the loss is calculated.dt
Optional[Float], optional - The time step increment. Defaults to 0.001.mask
Optional[Tensor], optional - The mask for the data point. Defaults to None.weight_type
str, optional - The type of weight to use for the loss. Can be "ones", "data_to_noise", or "variational_objective". Defaults to "ones".Returns:
Tensor
- The calculated loss batch tensor.def step_hybrid_sde(model_out: Tensor,
t: Tensor,
xt: Tensor,
dt: Tensor,
mask: Optional[Tensor] = None,
center: Bool = False,
temperature: Float = 1.0,
equilibrium_rate: Float = 0.0) -> Tensor
Do one step integration of Hybrid Langevin-Reverse Time SDE.
See section B.3 page 37 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf.
and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730
Arguments:
model_out
Tensor - The output of the model.xt
Tensor - The current data point.t
Tensor - The current time step.dt
Tensor - The time step increment.mask
Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.center
bool, optional - Whether to center the data. Defaults to False.temperature
Float, optional - The temperature parameter for low temperature sampling. Defaults to 1.0.equilibrium_rate
Float, optional - The rate of Langevin equilibration. Scales the amount of Langevin dynamics per unit time. Best values are in the range [1.0, 5.0]. Defaults to 0.0.Notes:
For all step functions that use the SDE formulation, it is important to note that we are moving backwards in time, which corresponds to an apparent sign change.
A clear example can be seen in slide 29 https://ernestryu.com/courses/FM/diffusion1.pdf.
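As a generic illustration (not the exact drift used here): a reverse-time SDE dx = [f(x, t) - g(t)**2 * s(x, t)] dt + g(t) dw is integrated with dt < 0, so an implementation that steps with a positive dt must negate the drift terms accordingly.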
def step_ode(model_out: Tensor,
t: Tensor,
xt: Tensor,
dt: Tensor,
mask: Optional[Tensor] = None,
center: Bool = False,
temperature: Float = 1.0) -> Tensor
Do one step integration of ODE.
See section B page 36 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf.
and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730
Arguments:
model_out
Tensor - The output of the model.xt
Tensor - The current data point.t
Tensor - The current time step.dt
Tensor - The time step increment.mask
Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.center
bool, optional - Whether to center the data. Defaults to False.temperature
Float, optional - The temperature parameter for low temperature sampling. Defaults to 1.0.class ContinuousFlowMatcher(Interpolant)
A Continuous Flow Matching interpolant.
Examples:
>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching import ContinuousFlowMatcher
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
flow_matcher = ContinuousFlowMatcher(
time_distribution = UniformTimeDistribution(...),
prior_distribution = GaussianPrior(...),
)
model = Model(...)
# Training
for epoch in range(1000):
data = data_loader.get(...)
time = flow_matcher.sample_time(batch_size)
noise = flow_matcher.sample_prior(data.shape)
noise, data = flow_matcher.apply_augmentation(noise, data) # Optional, only for OT
xt = flow_matcher.interpolate(data, time, noise)
flow = flow_matcher.calculate_target(data, noise)
u_pred = model(xt, time)
loss = flow_matcher.loss(u_pred, flow)
loss.backward()
# Generation
x_pred = flow_matcher.sample_prior(data.shape)
inference_sched = LinearInferenceSchedule(...)
for t in inference_sched.generate_schedule():
time = inference_sched.pad_time(x_pred.shape[0], t)
u_hat = model(x_pred, time)
x_pred = flow_matcher.step(u_hat, x_pred, time)
return x_pred
def __init__(time_distribution: TimeDistribution,
prior_distribution: PriorDistribution,
prediction_type: Union[PredictionType, str] = PredictionType.DATA,
sigma: Float = 0,
augmentation_type: Optional[Union[AugmentationType, str]] = None,
augmentation_num_threads: int = 1,
data_scale: Float = 1.0,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None,
eps: Float = 1e-5)
Initializes the Continuous Flow Matching interpolant.
Arguments:
time_distribution
TimeDistribution - The distribution of time steps, used to sample time points for the diffusion process.prior_distribution
PriorDistribution - The prior distribution of the variable, used as the starting point for the diffusion process.prediction_type
PredictionType, optional - The type of prediction, either "flow" or another type. Defaults to PredictionType.DATA.sigma
Float, optional - The standard deviation of the Gaussian noise added to the interpolated data. Defaults to 0.augmentation_type
Optional[Union[AugmentationType, str]], optional - The type of optimal transport, if applicable. Defaults to None.augmentation_num_threads
- Number of threads to use for OT solver. If "max", uses the maximum number of threads. Default is 1.data_scale
Float, optional - The scale factor for the data. Defaults to 1.0.device
Union[str, torch.device], optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".rng_generator
- An optional :class:torch.Generator
for reproducible sampling. Defaults to None.eps
- Small float to prevent divide by zerodef apply_augmentation(x0: Tensor,
x1: Tensor,
mask: Optional[Tensor] = None,
**kwargs) -> tuple
Sample and apply the optimal transport plan between batched (and masked) x0 and x1.
Arguments:
x0
Tensor - shape (bs, *dim), noise from source minibatch.x1
Tensor - shape (bs, *dim), data from source minibatch.mask
Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.**kwargs
- Additional keyword arguments to be passed to self.augmentation_sampler.apply_augmentation or handled within this method.Returns:
Tuple
- tuple of 2 tensors, represents the noise and data samples following OT plan pi.def undo_scale_data(data: Tensor) -> Tensor
Downscale the input data by the data scale factor.
Arguments:
data
Tensor - The input data to downscale.Returns:
The downscaled data.
def scale_data(data: Tensor) -> Tensor
Upscale the input data by the data scale factor.
Arguments:
data
Tensor - The input data to upscale.Returns:
The upscaled data.
def interpolate(data: Tensor, t: Tensor, noise: Tensor) -> Tensor
Get x_t with given time t from noise (x_0) and data (x_1).
Currently, we use the linear interpolation as defined in:
1. Rectified flow: https://arxiv.org/abs/2209.03003.
2. Conditional flow matching: https://arxiv.org/abs/2210.02747 (called conditional optimal transport).
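In code the interpolation is essentially the following sketch (using the pad_like helper documented below; the optional Gaussian jitter is an assumption based on the constructor's sigma argument):
t_pad = pad_like(t, data)              # broadcast time over node/feature dims
xt = t_pad * data + (1.0 - t_pad) * noise
# if sigma > 0, Gaussian jitter may additionally be added (assumption);
# the training target vector field is then u = data - noise.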
Arguments:
noise
Tensor - noise from prior(), shape (batchsize, nodes, features)t
Tensor - time, shape (batchsize)data
Tensor - target, shape (batchsize, nodes, features)def calculate_target(data: Tensor,
noise: Tensor,
mask: Optional[Tensor] = None) -> Tensor
Get the target vector field at time t.
Arguments:
noise
Tensor - noise from prior(), shape (batchsize, nodes, features)data
Tensor - target, shape (batchsize, nodes, features)mask
Optional[Tensor], optional - mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.Returns:
Tensor
- The target vector field at time t.def process_vector_field_prediction(model_output: Tensor,
xt: Optional[Tensor] = None,
t: Optional[Tensor] = None,
mask: Optional[Tensor] = None)
Process the model output based on the prediction type to calculate the vector field.
Arguments:
model_output
Tensor - The output of the model.xt
Tensor - The input sample.t
Tensor - The time step.mask
Optional[Tensor], optional - An optional mask to apply to the model output. Defaults to None.Returns:
The vector field prediction based on the prediction type.
Raises:
ValueError
- If the prediction type is not "flow" or "data".def process_data_prediction(model_output: Tensor,
xt: Optional[Tensor] = None,
t: Optional[Tensor] = None,
mask: Optional[Tensor] = None)
Process the model output based on the prediction type to generate clean data.
Arguments:
model_output
Tensor - The output of the model.xt
Tensor - The input sample.t
Tensor - The time step.mask
Optional[Tensor], optional - An optional mask to apply to the model output. Defaults to None.Returns:
The data prediction based on the prediction type.
Raises:
ValueError
- If the prediction type is not "flow".def step(model_out: Tensor,
xt: Tensor,
dt: Tensor,
t: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
center: Bool = False)
Perform a single ODE step integration using the Euler method.
Arguments:
model_out
Tensor - The output of the model at the current time step.xt
Tensor - The current intermediate state.dt
Tensor - The time step size.t
Tensor, optional - The current time. Defaults to None.mask
Optional[Tensor], optional - A mask to apply to the model output. Defaults to None.center
Bool, optional - Whether to center the output. Defaults to False.Returns:
x_next
Tensor - The updated state of the system after the single step, x_(t+dt).Notes:
clean
method is called on the updated state before it is returned.def step_score_stochastic(model_out: Tensor,
xt: Tensor,
dt: Tensor,
t: Tensor,
mask: Optional[Tensor] = None,
gt_mode: str = "tan",
gt_p: Float = 1.0,
gt_clamp: Optional[Float] = None,
score_temperature: Float = 1.0,
noise_temperature: Float = 1.0,
t_lim_ode: Float = 0.99,
center: Bool = False)
Perform a single SDE step integration using a score-based Langevin update.
d x_t = [v(x_t, t) + g(t) * s(x_t, t) * score_temperature] dt + \sqrt{2 * g(t) * noise_temperature} dw_t.
Arguments:
model_out
Tensor - The output of the model at the current time step.xt
Tensor - The current intermediate state.dt
Tensor - The time step size.t
Tensor, optional - The current time. Defaults to None.mask
Optional[Tensor], optional - A mask to apply to the model output. Defaults to None.gt_mode
str, optional - The mode for the gt function. Defaults to "tan".gt_p
Float, optional - The parameter for the gt function. Defaults to 1.0.gt_clamp
- (Float, optional): Upper limit of gt term. Defaults to None.score_temperature
Float, optional - The temperature for the score part of the step. Defaults to 1.0.noise_temperature
Float, optional - The temperature for the stochastic part of the step. Defaults to 1.0.t_lim_ode
Float, optional - The time limit for the ODE step. Defaults to 0.99.center
Bool, optional - Whether to center the output. Defaults to False.Returns:
x_next
Tensor - The updated state of the system after the single step, x_(t+dt).Notes:
clean
method is called on the updated state before it is returned.def loss(model_pred: Tensor,
target: Tensor,
t: Optional[Tensor] = None,
xt: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
target_type: Union[PredictionType, str] = PredictionType.DATA)
Calculate the loss given the model prediction, data sample, time, and mask.
If target_type is FLOW: loss = ||v_hat - (x1 - x0)||**2.
If target_type is DATA: loss = ||x1_hat - x1||**2 / (1 - t)**2, since the target vector field x1 - x0 = (x1 - xt) / (1 - t), where xt = t * x1 + (1 - t) * x0.
This function supports any combination of prediction_type and target_type in {DATA, FLOW}.
Arguments:
model_pred
Tensor - The predicted output from the model.target
Tensor - The target output for the model prediction.t
Optional[Tensor], optional - The time for the model prediction. Defaults to None.xt
Optional[Tensor], optional - The interpolated data. Defaults to None.mask
Optional[Tensor], optional - The mask for the data point. Defaults to None.target_type
PredictionType, optional - The type of the target output. Defaults to PredictionType.DATA.Returns:
Tensor
- The calculated loss batch tensor.def vf_to_score(x_t: Tensor, v: Tensor, t: Tensor) -> Tensor
From Geffner et al.: computes the score of the noisy density given the vector field learned by flow matching.
With our interpolation scheme these are related by
v(x_t, t) = (1 / t) * (x_t + scale_ref**2 * (1 - t) * s(x_t, t)),
or equivalently,
s(x_t, t) = (t * v(x_t, t) - x_t) / (scale_ref**2 * (1 - t)),
with scale_ref = 1.
Arguments:
x_t
- Noisy sample, shape [*, dim]v
- Vector field, shape [*, dim]t
- Interpolation time, shape [*] (must be < 1)Returns:
Score of intermediate density, shape [*, dim].
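A direct transcription of the identity above with scale_ref = 1 (using the pad_like helper documented below):
def vf_to_score(x_t, v, t):
    t_pad = pad_like(t, x_t)
    return (t_pad * v - x_t) / (1.0 - t_pad)  # valid for t < 1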
def get_gt(t: Tensor,
mode: str = "tan",
param: float = 1.0,
clamp_val: Optional[float] = None,
eps: float = 1e-2) -> Tensor
From Geffner et al. Computes gt for different modes.
Arguments:
t
- times where we'll evaluate, covers [0, 1), shape [nsteps]mode
- "us" or "tan"param
- parameterized transformationclamp_val
- value to clamp gt, no clamping if Noneeps
- small value; leave at the default.class BatchDataAugmentation()
Facilitates the creation of batch augmentation objects based on specified optimal transport types.
Arguments:
device
str - The device to use for computations (e.g., 'cpu', 'cuda').num_threads
int - The number of threads to utilize.def __init__(device, num_threads)
Initializes a BatchDataAugmentation instance.
Arguments:
device
str - Device for computation.num_threads
int - Number of threads to use.def create(method_type: AugmentationType)
Creates a batch augmentation object of the specified type.
Arguments:
method_type
AugmentationType - The type of optimal transport method.Returns:
The augmentation object if the type is supported, otherwise None.
class D3PM(Interpolant)
A Discrete Denoising Diffusion Probabilistic Model (D3PM) interpolant.
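A usage sketch by analogy with the DDPM example later in this section (constructor arguments are placeholders, the noise schedule class is assumed, and only methods documented below are used):
d3pm = D3PM(
time_distribution = UniformTimeDistribution(discrete_time = True,...),
prior_distribution = DiscreteUniformPrior(...),
noise_schedule = DiscreteNoiseSchedule(...),
)
# Training
time = d3pm.sample_time(batch_size)
xt = d3pm.interpolate(data, time)
logits = model(xt, time)
loss = d3pm.loss(logits, data, xt, time)
# Generation: iterate step(logits, t, xt) from t = T-1 down to 0.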
def __init__(time_distribution: TimeDistribution,
prior_distribution: DiscretePriorDistribution,
noise_schedule: DiscreteNoiseSchedule,
device: str = "cpu",
last_time_idx: int = 0,
rng_generator: Optional[torch.Generator] = None)
Initializes the D3PM interpolant.
Arguments:
time_distribution
TimeDistribution - The distribution of time steps, used to sample time points for the diffusion process.prior_distribution
PriorDistribution - The prior distribution of the variable, used as the starting point for the diffusion process.noise_schedule
DiscreteNoiseSchedule - The schedule of noise, defining the amount of noise added at each time step.device
str, optional - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".last_time_idx
int, optional - The last time index to consider in the interpolation process. Defaults to 0.rng_generator
- An optional :class:torch.Generator
for reproducible sampling. Defaults to None.def interpolate(data: Tensor, t: Tensor)
Interpolate using discrete interpolation method.
This method implements Equation 2 from the D3PM paper (https://arxiv.org/pdf/2107.03006), which
calculates the interpolated discrete state xt
at time t
given the input data and noise
via q(xt|x0) = Cat(xt; p = x0*Qt_bar).
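A minimal sketch of this sampling rule (the Qt_bar access pattern and shapes are assumptions; the interpolant holds the cumulative transition matrices internally):
import torch.nn.functional as F
# Rows of Qt_bar[t] indexed by the clean tokens x0 give the Cat parameters.
probs = F.one_hot(data, num_classes).float() @ Qt_bar[t]   # (batch, nodes, num_classes)
xt = torch.multinomial(probs.reshape(-1, num_classes), 1).reshape(data.shape)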
Arguments:
data
Tensor - The input data to be interpolated.t
Tensor - The time step at which to interpolate.Returns:
Tensor
- The interpolated discrete state xt
at time t
.def forward_process(data: Tensor, t: Tensor) -> Tensor
Apply the forward process to the data at time t.
Arguments:
data
Tensor - target discrete idst
Tensor - timeReturns:
Tensor
- x(t) after applying the forward processdef step(model_out: Tensor,
t: Tensor,
xt: Tensor,
mask: Optional[Tensor] = None,
temperature: Float = 1.0,
model_out_is_logits: bool = True)
Perform a single step in the discrete interpolant method, transitioning from the current discrete state xt
at time t
to the next state.
This step involves:
model_out
and the current state xt
at time t
.Arguments:
model_out
Tensor - The output of the model at the current time step, which is used to compute the predicted q-posterior logits.t
Tensor - The current time step, which is used to index into the transition matrices and compute the predicted q-posterior logits.xt
Tensor - The current discrete state at time t
, which is used to compute the predicted q-posterior logits and sample the next state.mask
Optional[Tensor], optional - An optional mask to apply to the next state, which can be used to mask out certain tokens or regions. Defaults to None.temperature
Float, optional - The temperature to use for the Gumbel-Softmax trick, which controls the randomness of the sampling process. Defaults to 1.0.model_out_is_logits
bool, optional - A flag indicating whether the model output is already in logits form. If True, the output is assumed to be logits; otherwise, it is converted to logits. Defaults to True.Returns:
Tensor
- The next discrete state at time t-1
.def loss(logits: Tensor,
target: Tensor,
xt: Tensor,
time: Tensor,
mask: Optional[Tensor] = None,
vb_scale: Float = 0.0)
Calculate the cross-entropy loss between the model prediction and the target output.
The loss is calculated between the batch x node x class logits and the target batch x node. If a mask is provided, the loss is
calculated only for the non-masked elements. Additionally, if vb_scale is greater than 0, the variational lower bound loss is
calculated and added to the total loss.
Arguments:
logits
Tensor - The predicted output from the model, with shape batch x node x class.target
Tensor - The target output for the model prediction, with shape batch x node.xt
Tensor - The current data point.time
Tensor - The time at which the loss is calculated.mask
Optional[Tensor], optional - The mask for the data point. Defaults to None.vb_scale
Float, optional - The scale factor for the variational lower bound loss. Defaults to 0.0.Returns:
Tensor
- The calculated loss tensor. If aggregate is True, the loss and variational lower bound loss are aggregated and returned as a single tensor.class DDPM(Interpolant)
A Denoising Diffusion Probabilistic Model (DDPM) interpolant.
Examples:
>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
>>> from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
>>> from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule
ddpm = DDPM(
time_distribution = UniformTimeDistribution(discrete_time = True,...),
prior_distribution = GaussianPrior(...),
noise_schedule = DiscreteCosineNoiseSchedule(...),
)
model = Model(...)
# Training
for epoch in range(1000):
data = data_loader.get(...)
time = ddpm.sample_time(batch_size)
noise = ddpm.sample_prior(data.shape)
xt = ddpm.interpolate(data, noise, time)
x_pred = model(xt, time)
loss = ddpm.loss(x_pred, data, time)
loss.backward()
# Generation
x_pred = ddpm.sample_prior(data.shape)
for t in DiscreteLinearInferenceSchedule(...).generate_schedule():
time = torch.full((batch_size,), t)
x_hat = model(x_pred, time)
x_pred = ddpm.step(x_hat, time, x_pred)
return x_pred
def __init__(time_distribution: TimeDistribution,
prior_distribution: PriorDistribution,
noise_schedule: DiscreteNoiseSchedule,
prediction_type: Union[PredictionType, str] = PredictionType.DATA,
device: Union[str, torch.device] = "cpu",
last_time_idx: int = 0,
rng_generator: Optional[torch.Generator] = None)
Initializes the DDPM interpolant.
Arguments:
time_distribution
TimeDistribution - The distribution of time steps, used to sample time points for the diffusion process.prior_distribution
PriorDistribution - The prior distribution of the variable, used as the starting point for the diffusion process.noise_schedule
DiscreteNoiseSchedule - The schedule of noise, defining the amount of noise added at each time step.prediction_type
PredictionType - The type of prediction, either "data" or another type. Defaults to "data".device
str - The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".last_time_idx
int, optional - The last time index for discrete time. Set to 0 if discrete time is T-1, ..., 0 or 1 if T, ..., 1. Defaults to 0.rng_generator
- An optional :class:torch.Generator
for reproducible sampling. Defaults to None.@property
def forward_data_schedule() -> torch.Tensor
Returns the forward data schedule.
@property
def forward_noise_schedule() -> torch.Tensor
Returns the forward noise schedule.
@property
def reverse_data_schedule() -> torch.Tensor
Returns the reverse data schedule.
@property
def reverse_noise_schedule() -> torch.Tensor
Returns the reverse noise schedule.
@property
def log_var() -> torch.Tensor
Returns the log variance.
@property
def alpha_bar() -> torch.Tensor
Returns the alpha bar values.
@property
def alpha_bar_prev() -> torch.Tensor
Returns the previous alpha bar values.
def interpolate(data: Tensor, t: Tensor, noise: Tensor)
Get x(t) with given time t from noise and data.
Arguments:
data
Tensor - targett
Tensor - timenoise
Tensor - noise from prior()def forward_process(data: Tensor, t: Tensor, noise: Optional[Tensor] = None)
Get x(t) with given time t from noise and data.
Arguments:
data
Tensor - targett
Tensor - timenoise
Tensor, optional - noise from prior(). Defaults to None.def process_data_prediction(model_output: Tensor, sample: Tensor, t: Tensor)
Converts the model output to a data prediction based on the prediction type.
This conversion stems from the Progressive Distillation for Fast Sampling of Diffusion Models https://arxiv.org/pdf/2202.00512.
Given the model output and the sample, we convert the output to a data prediction based on the prediction type.
The conversion formulas are as follows:
- For "noise" prediction type: pred_data = (sample - noise_scale * model_output) / data_scale
- For "data" prediction type: pred_data = model_output
- For "v_prediction" prediction type: pred_data = data_scale * sample - noise_scale * model_output
Arguments:
model_output
Tensor - The output of the model.sample
Tensor - The input sample.t
Tensor - The time step.Returns:
The data prediction based on the prediction type.
Raises:
ValueError
- If the prediction type is not one of "noise", "data", or "v_prediction".def process_noise_prediction(model_output, sample, t)
Does the same as process_data_prediction but converts the model output to noise.
Arguments:
model_output
- The output of the model.sample
- The input sample.t
- The time step.Returns:
The input as noise if the prediction type is "noise".
Raises:
ValueError
- If the prediction type is not "noise".def calculate_velocity(data: Tensor, t: Tensor, noise: Tensor) -> Tensor
Calculate the velocity term given the data, time step, and noise.
Arguments:
data
Tensor - The input data.t
Tensor - The current time step.noise
Tensor - The noise term.Returns:
Tensor
- The calculated velocity term.
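For the standard v-parameterization this is (stated as an assumption about the exact coefficients used here): v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * data.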
@torch.no_grad()
def step(model_out: Tensor,
t: Tensor,
xt: Tensor,
mask: Optional[Tensor] = None,
center: Bool = False,
temperature: Float = 1.0)
Do one step integration.
Arguments:
model_out
Tensor - The output of the model.t
Tensor - The current time step.xt
Tensor - The current data point.mask
Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.center
bool, optional - Whether to center the data. Defaults to False.temperature
Float, optional - The temperature parameter for low temperature sampling. Defaults to 1.0.Notes:
The temperature parameter controls the level of randomness in the sampling process. A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result in less random and more deterministic samples. This can be useful for tasks that require more control over the generation process.
Note that for discrete time we sample from [T-1, ..., 1, 0] over T steps, so t = 0 is sampled, hence the mask.
For continuous time we start from [1, 1 - dt, ..., dt] over T steps, where s = t - dt, so at the final step (t = dt) we have s = 0.
def step_noise(model_out: Tensor,
t: Tensor,
xt: Tensor,
mask: Optional[Tensor] = None,
center: Bool = False,
temperature: Float = 1.0)
Do one step integration.
Arguments:
model_out
Tensor - The output of the model.t
Tensor - The current time step.xt
Tensor - The current data point.mask
Optional[Tensor], optional - An optional mask to apply to the data. Defaults to None.center
bool, optional - Whether to center the data. Defaults to False.temperature
Float, optional - The temperature parameter for low temperature sampling. Defaults to 1.0.Notes:
The temperature parameter controls the level of randomness in the sampling process.
A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2)
result in less random and more deterministic samples. This can be useful for tasks
that require more control over the generation process.
Note that for discrete time we sample from [T-1, ..., 1, 0] over T steps, so t = 0 is sampled, hence the mask.
For continuous time we start from [1, 1 - dt, ..., dt] over T steps, where s = t - dt, so at the final step (t = dt) we have s = 0.
def score(x_hat: Tensor, xt: Tensor, t: Tensor)
Converts the data prediction to the estimated score function.
Arguments:
x_hat
Tensor - The predicted data point.xt
Tensor - The current data point.t
Tensor - The time step.Returns:
The estimated score function.
def step_ddim(model_out: Tensor,
t: Tensor,
xt: Tensor,
mask: Optional[Tensor] = None,
eta: Float = 0.0,
center: Bool = False)
Do one step of DDIM sampling.
Arguments:
model_out
Tensor - output of the modelt
Tensor - current time stepxt
Tensor - current data pointmask
Optional[Tensor], optional - mask for the data point. Defaults to None.eta
Float, optional - DDIM sampling parameter. Defaults to 0.0.center
Bool, optional - whether to center the data point. Defaults to False.def set_loss_weight_fn(fn)
Sets the loss_weight attribute of the instance to the given function.
Arguments:
fn
- The function to set as the loss_weight attribute. This function should take three arguments: raw_loss, t, and weight_type.def loss_weight(raw_loss: Tensor, t: Optional[Tensor],
weight_type: str) -> Tensor
Calculates the weight for the loss based on the given weight type.
The data_to_noise loss weight is derived in Equation (9) of https://arxiv.org/pdf/2202.00512.
Arguments:
raw_loss
Tensor - The raw loss calculated from the model prediction and target.t
Tensor - The time step.weight_type
str - The type of weight to use. Can be "ones" or "data_to_noise" or "noise_to_data".Returns:
Tensor
- The weight for the loss.Raises:
ValueError
- If the weight type is not recognized.def loss(model_pred: Tensor,
target: Tensor,
t: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
weight_type: Literal["ones", "data_to_noise",
"noise_to_data"] = "ones")
Calculate the loss given the model prediction, data sample, and time.
The default weight_type is "ones", meaning the loss is unchanged (multiplied by all ones).
data_to_noise is available to scale the data MSE loss into the loss that is theoretically equivalent
to noise prediction; noise_to_data is provided for the reverse direction, for completeness.
Arguments:
model_pred
Tensor - The predicted output from the model.target
Tensor - The target output for the model prediction.t
Tensor - The time at which the loss is calculated.mask
Optional[Tensor], optional - The mask for the data point. Defaults to None.weight_type
Literal["ones", "data_to_noise", "noise_to_data"] - The type of weight to use for the loss. Defaults to "ones".Returns:
Tensor
- The calculated loss batch tensor.def safe_index(tensor: Tensor, index: Tensor, device: Optional[torch.device])
Safely indexes a tensor using a given index and returns the result on a specified device.
Note: forcing could be implemented as return tensor[index.to(tensor.device)].to(device), but the device migration is costly.
Arguments:
tensor
Tensor - The tensor to be indexed.index
Tensor - The index to use for indexing the tensor.device
torch.device - The device on which the result should be returned.Returns:
Tensor
- The indexed tensor on the specified device.Raises:
ValueError
- If tensor, index are not all on the same device.def string_to_enum(value: Union[str, AnyEnum],
enum_type: Type[AnyEnum]) -> AnyEnum
Converts a string to an enum value of the specified type. If the input is already an enum instance, it is returned as-is.
Arguments:
value
Union[str, E] - The string to convert or an existing enum instance.enum_type
Type[E] - The enum type to convert to.Returns:
E
- The corresponding enum value.Raises:
ValueError
- If the string does not correspond to any enum member.def pad_like(source: Tensor, target: Tensor) -> Tensor
Pads the dimensions of the source tensor to match the dimensions of the target tensor.
Arguments:
source
Tensor - The tensor to be padded.target
Tensor - The tensor that the source tensor should match in dimensions.Returns:
Tensor
- The padded source tensor.Raises:
ValueError
- If the source tensor has more dimensions than the target tensor.Example:
source = torch.tensor([1, 2, 3]) # shape: (3,)
target = torch.tensor([[1, 2], [4, 5], [7, 8]]) # shape: (3, 2)
padded_source = pad_like(source, target) # shape: (3, 1)
class PredictionType(Enum)
An enumeration representing the type of prediction a Denoising Diffusion Probabilistic Model (DDPM) can be used for.
DDPMs are versatile models that can be utilized for various prediction tasks, including "data", "noise", and "v_prediction".
These prediction types can be used to train neural networks for specific tasks, such as denoising, image synthesis, or time-series forecasting.
class Interpolant(ABC)
An abstract base class representing an Interpolant.
This class serves as a foundation for creating interpolants that can be used
in various applications, providing a basic structure and interface for
interpolation-related operations.
def __init__(time_distribution: TimeDistribution,
prior_distribution: PriorDistribution,
device: Union[str, torch.device] = "cpu",
rng_generator: Optional[torch.Generator] = None)
Initializes the Interpolant class.
Arguments:
time_distribution
TimeDistribution - The distribution of time steps.prior_distribution
PriorDistribution - The prior distribution of the variable.device
Union[str, torch.device], optional - The device on which to operate. Defaults to "cpu".rng_generator
- An optional :class:torch.Generator
for reproducible sampling. Defaults to None.@abstractmethod
def interpolate(*args, **kwargs) -> Tensor
Get x(t) with given time t from noise and data.
Interpolate between x0 and x1 at the given time t.
@abstractmethod
def step(*args, **kwargs) -> Tensor
Do one step integration.
def general_step(method_name: str, kwargs: dict)
Calls a step method of the class by its name, passing the provided keyword arguments.
Arguments:
method_name
str - The name of the step method to call.kwargs
dict - Keyword arguments to pass to the step method.Returns:
The result of the step method call.
Raises:
ValueError
- If the provided method name does not start with 'step'.Exception
- If the step method call fails. The error message includes a list of available step methods.Notes:
This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.
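For example, with tensors matching the step signatures documented for DDPM above:
x_next = interpolant.general_step("step_ddim", {"model_out": model_out, "t": t, "xt": xt})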
def sample_prior(*args, **kwargs) -> Tensor
Sample from prior distribution.
This method generates a sample from the prior distribution specified by the
prior_distribution
attribute.
Returns:
Tensor
- The generated sample from the prior distribution.def sample_time(*args, **kwargs) -> Tensor
Sample from time distribution.
def to_device(device: str)
Moves all internal tensors to the specified device and updates the self.device
attribute.
Arguments:
device
str - The device to move the tensors to (e.g. "cpu", "cuda:0").Notes:
This method is used to transfer the internal state of the interpolant to a different device.
It updates the self.device
attribute to reflect the new device and moves all internal tensors to the specified device.
def clean_mask_center(data: Tensor,
mask: Optional[Tensor] = None,
center: Bool = False) -> Tensor
Returns a clean tensor that has been masked and/or centered based on the function arguments.
Arguments:
data
- The input data with shape (..., nodes, features).mask
- An optional mask to apply to the data with shape (..., nodes). If provided, it is used to calculate the CoM. Defaults to None.center
- A boolean indicating whether to center the data around the calculated CoM. Defaults to False.Returns:
The data with shape (..., nodes, features) either centered around the CoM if center
is True or unchanged if center
is False.
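A minimal sketch of the masked centering (the exact order of operations is an assumption):
m = mask.unsqueeze(-1) if mask is not None else torch.ones_like(data[..., :1])
if center:
    com = (data * m).sum(dim=-2, keepdim=True) / m.sum(dim=-2, keepdim=True)
    data = data - com
if mask is not None:
    data = data * m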
@contextmanager
def parallel_context(rank: int = 0, world_size: int = 1)
Context manager for torch distributed testing.
Sets up and cleans up the distributed environment, including the device mesh.
Arguments:
rank
int - The rank of the process. Defaults to 0.world_size
int - The world size of the distributed environment. Defaults to 1.Yields:
None
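For example:
with parallel_context(rank=0, world_size=1):
    ...  # body of a distributed test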
def clean_up_distributed() -> None
Cleans up the distributed environment.
Destroys the process group and empties the CUDA cache.
Arguments:
None
Returns:
None