# -*- coding: utf-8 -*-
import math
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from scipy import optimize
from sklearn.linear_model._base import LinearModel, LinearClassifierMixin
from . import ops
from .kernels import kernels
from .utils import spherical_kmeans, gaussian_filter_1d, normalize_, flip, EPS
class CKNLayer(nn.Conv1d):
def __init__(self, in_channels, out_channels, filter_size,
padding=0, dilation=1, groups=1, subsampling=1,
kernel_func="exp", kernel_args=[0.5],
kernel_args_trainable=False):
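        # Only the string "SAME" is honored for `padding`: it pads so that the
        # convolution preserves the sequence length; any other value (including
        # explicit integers) falls back to zero padding.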
if padding == "SAME":
padding = (filter_size - 1)//2
else:
padding = 0
super(CKNLayer, self).__init__(in_channels, out_channels, filter_size,
stride=1, padding=padding,
dilation=dilation, groups=groups,
bias=False)
self.subsampling = subsampling
self.filter_size = filter_size
self.patch_dim = self.in_channels * self.filter_size
self._need_lintrans_computed = True
self.kernel_args_trainable = kernel_args_trainable
if isinstance(kernel_args, (int, float)):
kernel_args = [kernel_args]
if kernel_func == "exp" or kernel_func == "add_exp":
kernel_args = [1./kernel_arg ** 2 for kernel_arg in kernel_args]
self.kernel_args = kernel_args
if kernel_args_trainable:
self.kernel_args = nn.ParameterList([nn.Parameter(torch.Tensor(
[kernel_arg])) for kernel_arg in kernel_args])
kernel_func = kernels[kernel_func]
self.kappa = lambda x: kernel_func(x, *self.kernel_args)
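        # kappa is the dot-product kernel applied elementwise; for the
        # exponential kernels the arguments were converted above to 1/sigma^2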
ones = torch.ones(1, self.in_channels // self.groups, self.filter_size)
self.register_buffer("ones", ones)
self.init_pooling_filter()
self.register_buffer("lintrans",
torch.Tensor(out_channels, out_channels))
def init_pooling_filter(self):
if self.subsampling <= 1:
return
size = 2 * self.subsampling - 1
pooling_filter = gaussian_filter_1d(size)
pooling_filter = pooling_filter.expand(self.out_channels, 1, size)
self.register_buffer("pooling_filter", pooling_filter)
def train(self, mode=True):
super(CKNLayer, self).train(mode)
        if self.training:
self._need_lintrans_computed = True
def _compute_lintrans(self):
"""Compute the linear transformation factor kappa(ZtZ)^(-1/2)
Returns:
lintrans: out_channels x out_channels
"""
if not self._need_lintrans_computed:
return self.lintrans
lintrans = self.weight.view(self.out_channels, -1)
lintrans = lintrans.mm(lintrans.t())
lintrans = self.kappa(lintrans)
lintrans = ops.matrix_inverse_sqrt(lintrans)
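        # in evaluation mode the factor is cached in the registered buffer so the
        # inverse square root is only computed once; train() resets the flag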
if not self.training:
self._need_lintrans_computed = False
self.lintrans.data = lintrans.data
return lintrans
def _conv_layer(self, x_in):
"""Convolution layer
Compute x_out = ||x_in|| x kappa(Zt x_in/||x_in||)
Args:
x_in: batch_size x in_channels x H
self.filters: out_channels x in_channels x filter_size
x_out: batch_size x out_channels x (H - filter_size + 1)
"""
patch_norm = torch.sqrt(F.conv1d(x_in.pow(2), self.ones,
padding=self.padding, dilation=self.dilation,
groups=self.groups).clamp(min=EPS))
# patch_norm = patch_norm.clamp(EPS)
x_out = super(CKNLayer, self).forward(x_in)
x_out = x_out / patch_norm.clamp(min=EPS)
x_out = self.kappa(x_out)
x_out = patch_norm * x_out
return x_out
def _mult_layer(self, x_in, lintrans):
"""Multiplication layer
Compute x_out = kappa(ZtZ)^(-1/2) x x_in
Args:
x_in: batch_size x out_channels x H
lintrans: out_channels x out_channels
x_out: batch_size x out_channels x H
"""
        if x_in.dim() == 2:
            return torch.mm(x_in, lintrans)
        batch_size, out_c, _ = x_in.size()
        return torch.bmm(lintrans.expand(batch_size, out_c, out_c), x_in)
def _pool_layer(self, x_in):
"""Pooling layer
Compute I(z) = \sum_{z'} phi(z') x exp(-\beta_1 ||z'-z||_2^2)
"""
if self.subsampling <= 1:
return x_in
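        # Gaussian-weighted pooling, implemented as a depthwise (grouped)
        # convolution subsampled with stride equal to the subsampling factor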
x_out = F.conv1d(x_in, self.pooling_filter, stride=self.subsampling,
padding=self.subsampling-1, groups=self.out_channels)
return x_out
def forward(self, x_in):
"""Encode function for a CKN layer
Args:
            x_in: batch_size x in_channels x H
"""
x_out = self._conv_layer(x_in)
x_out = self._pool_layer(x_out)
lintrans = self._compute_lintrans()
x_out = self._mult_layer(x_out, lintrans)
return x_out
def compute_mask(self, mask=None):
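        # an output position is kept if its receptive field overlaps at least one
        # unmasked input position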
if mask is None:
return mask
mask = mask.float().unsqueeze(1)
mask = F.avg_pool1d(mask, kernel_size=self.filter_size,
stride=self.subsampling)
mask = mask.squeeze(1) != 0
return mask
def extract_1d_patches(self, input, mask=None):
output = input.unfold(-1, self.filter_size, 1).transpose(1, 2)
output = output.contiguous().view(-1, self.patch_dim)
if mask is not None:
mask = mask.float().unsqueeze(1)
mask = F.avg_pool1d(mask, kernel_size=self.filter_size, stride=1)
# option 2: mask = mask.view(-1) == 1./self.filter_size
mask = mask.view(-1) != 0
output = output[mask]
return output
def sample_patches(self, x_in, mask=None, n_sampling_patches=1000):
"""Sample patches from the given Tensor
Args:
x_in (Tensor batch_size x in_channels x H)
n_sampling_patches (int): number of patches to sample
Returns:
patches: (batch_size x (H - filter_size + 1)) x (in_channels x filter_size)
"""
patches = self.extract_1d_patches(x_in, mask)
n_sampling_patches = min(patches.size(0), n_sampling_patches)
indices = torch.randperm(patches.size(0))[:n_sampling_patches]
patches = patches[indices]
normalize_(patches)
return patches
def unsup_train(self, patches, init=None):
"""Unsupervised training for a CKN layer
Args:
patches: n x in_channels x H
Updates:
filters: out_channels x in_channels x filter_size
"""
weight = spherical_kmeans(patches, self.out_channels, init=init)
weight = weight.view_as(self.weight)
self.weight.data = weight.data
self._need_lintrans_computed = True
def normalize_(self):
norm = self.weight.data.view(
self.out_channels, -1).norm(p=2, dim=-1).view(-1, 1, 1)
norm.clamp_(min=EPS)
self.weight.data.div_(norm)
class BioEmbedding(nn.Module):
def __init__(self, num_embeddings, reverse_complement=False,
mask_zeros=False, no_embed=False, encoding='one_hot'):
"""Embedding layer for biosequences
Args:
num_embeddings (int): number of letters in alphabet
reverse_complement (boolean): reverse complement embedding or not
"""
super(BioEmbedding, self).__init__()
self.num_embeddings = num_embeddings
self.reverse_complement = reverse_complement
self.mask_zeros = mask_zeros
self.no_embed = no_embed
if no_embed:
return
self.embedding = lambda x, weight: F.embedding(x, weight)
if encoding == 'blosum62':
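            # NOTE: BLOSUM62 (a substitution-matrix table) is assumed to be
            # defined or imported elsewhere in the package; it is not imported in
            # this file.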
weight = torch.from_numpy(BLOSUM62.astype(np.float32))
else:
weight = self._make_weight(False)
self.register_buffer("weight", weight)
if reverse_complement:
weight_rc = self._make_weight(True)
self.register_buffer("weight_rc", weight_rc)
def _make_weight(self, reverse_complement=False):
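        # row 0 encodes the padding/unknown symbol as a uniform distribution over
        # the alphabet; rows 1..num_embeddings are one-hot, with the columns
        # reversed for the reverse-complement weights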
if reverse_complement:
weight = np.zeros((self.num_embeddings + 1, self.num_embeddings),
dtype=np.float32)
weight[0] = 1./self.num_embeddings
weight[1:] = np.fliplr(np.diag(np.ones(self.num_embeddings)))
weight = torch.from_numpy(weight)
else:
weight = torch.zeros(self.num_embeddings + 1, self.num_embeddings)
weight[0] = 1./self.num_embeddings
weight[1:] = torch.diag(torch.ones(self.num_embeddings))
return weight
def compute_mask(self, x):
"""Compute the mask for the given Tensor
"""
if self.no_embed:
if self.mask_zeros:
s = x.norm(dim=1)
mask = (s != 0)
else:
mask = None
return mask
if self.mask_zeros:
mask = (x != 0)
if self.reverse_complement:
mask_rc = flip(mask, dim=-1)
mask = torch.cat((mask, mask_rc))
return mask
return None
def forward(self, x):
"""
Args:
x: LongTensor of indices
"""
if self.no_embed:
return x
x_out = self.embedding(x, self.weight)
if self.reverse_complement:
# reverse sequence
x = flip(x, dim=-1)
x_out_rc = self.embedding(x, self.weight_rc)
x_out = torch.cat((x_out, x_out_rc), dim=0)
return x_out.transpose(1, 2).contiguous()
class GlobalAvg1D(nn.Module):
def __init__(self):
super(GlobalAvg1D, self).__init__()
def forward(self, x, mask=None):
if mask is None:
return x.mean(dim=-1)
mask = mask.float().unsqueeze(1)
x = x * mask
return x.sum(dim=-1)/mask.sum(dim=-1)
class GlobalMax1D(nn.Module):
def __init__(self):
super(GlobalMax1D, self).__init__()
def forward(self, x, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(x)
            # mask out padded positions so they cannot win the max
            x = x.masked_fill(~mask, -float("inf"))
return x.max(dim=-1)[0]
class GMP(nn.Module):
def __init__(self, alpha=1e-03):
super(GMP, self).__init__()
self.alpha = alpha
def forward(self, x, mask=None):
if mask is not None:
mask = mask.float().unsqueeze(1)
x = x * mask
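        # generalized max pooling: average the columns of (X X^T + alpha*I)^{-1} X,
        # computed below through a batched linear solve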
xxt = torch.bmm(x, x.transpose(1, 2))
xxt.diagonal(dim1=1, dim2=2)[:] += self.alpha
eye = xxt.new_ones(xxt.size(-1)).diag().expand_as(xxt)
        # torch.gesv has been removed from PyTorch; solve xxt @ out = eye instead
        xxt = torch.linalg.solve(xxt, eye)
x = torch.bmm(xxt, x)
return x.mean(dim=-1) #/ mask.sum(dim=-1)
POOLINGS = {'mean': GlobalAvg1D, 'max': GlobalMax1D, 'gmp': GMP}
class Preprocessor(nn.Module):
def __init__(self):
super(Preprocessor, self).__init__()
self.fitted = True
def forward(self, input):
        out = input - input.mean(dim=1, keepdim=True)
        return out / out.norm(dim=1, keepdim=True).clamp(min=EPS)
def fit(self, input):
pass
def fit_transform(self, input):
self.fit(input)
return self(input)
class RowPreprocessor(nn.Module):
def __init__(self):
super(RowPreprocessor, self).__init__()
self.register_buffer("mean", None)
self.register_buffer("var", None)
self.register_buffer("scale", None)
self.count = 0
self.fitted = False
def reset(self):
self.mean = None
self.var = None
self.scale = None
self.count = 0.
self.fitted = False
def forward(self, input):
if not self.fitted:
# self.partial_fit(input)
return input
input -= self.mean
input /= self.scale
return input
def fit(self, input):
self.mean = input.mean(dim=0)
self.var = input.var(dim=0, unbiased=False)
        self.scale = self.var.sqrt()
        self.fitted = True
def fit_transform(self, input):
self.fit(input)
return self(input)
def partial_fit(self, input):
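        # incremental mean/variance update (pairwise scheme of Chan et al., as in
        # sklearn's partial_fit for standard scaling)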
if self.count == 0.:
self.mean = input.mean(0)
self.var = input.var(0, unbiased=False)
self.scale = self.var.sqrt()
self.count += input.shape[0]
else:
last_sum = self.count * self.mean
new_sum = input.sum(0)
updated_count = self.count + input.shape[0]
self.mean = (last_sum + new_sum) / updated_count
new_unnorm_var = input.var(0, unbiased=False) * input.shape[0]
last_unnorm_var = self.var * self.count
last_over_new_count = self.count / input.shape[0]
self.var = (
new_unnorm_var + last_unnorm_var +
last_over_new_count / updated_count *
(last_sum / last_over_new_count - new_sum) ** 2) / updated_count
self.count = updated_count
self.scale = self.var.sqrt()
def _load_from_state_dict(self, state_dict, prefix, metadata,
strict, missing_keys, unexpected_keys, error_msgs):
self.fitted = True
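        # the statistics buffers were registered as None, so materialize them from
        # the checkpoint before the default buffer loading runs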
for k, v in self._buffers.items():
key = prefix + k
setattr(self, k, state_dict[key])
super(RowPreprocessor, self)._load_from_state_dict(
state_dict, prefix, metadata,
strict, missing_keys, unexpected_keys, error_msgs)
PREPROCESSORS = {'standard_col': Preprocessor, 'standard_row': RowPreprocessor}
class LinearMax(nn.Linear, LinearModel, LinearClassifierMixin):
def __init__(self, in_features, out_features, alpha=0.0, fit_bias=True,
reverse_complement=False, penalty="l2"):
super(LinearMax, self).__init__(in_features, out_features, fit_bias)
self.alpha = alpha
self.reverse_complement = reverse_complement
self.fit_bias = fit_bias
self.penalty = penalty
def forward(self, input, proba=False):
bias = self.bias
if bias is not None and hasattr(self, 'scale_bias') and self.scale_bias is not None:
bias = self.scale_bias * bias
out = F.linear(input, self.weight, bias)
if self.reverse_complement:
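            # the batch holds forward strands followed by their reverse
            # complements; keep the larger of the two scores per sequence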
n_samples = out.size(0)//2
out = torch.max(out[:n_samples], out[n_samples:])
if proba:
return out.sigmoid()
return out
def fit(self, x, y, criterion=None):
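        # full-batch fit with scipy's L-BFGS-B; eval_loss and eval_grad below wrap
        # the PyTorch forward/backward pass for the scipy optimizer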
use_cuda = self.weight.data.is_cuda
if criterion is None:
criterion = nn.BCEWithLogitsLoss()
reduction = criterion.reduction
criterion.reduction = 'sum'
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if isinstance(y, np.ndarray):
            y = torch.from_numpy(y)
if self.reverse_complement:
n_samples, n_features = x.shape
n_features = n_features // 2
x = torch.cat([x[:, :n_features], x[:, n_features:]])
if use_cuda:
x = x.cuda()
y = y.cuda()
if self.bias is not None:
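            # the bias is optimized jointly with the weights as an extra column,
            # scaled by the typical feature magnitude to keep the problem well
            # conditioned; the trained bias is folded back at the end of fit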
scale_bias = (x ** 2).mean(-1).sqrt().mean().item()
self.scale_bias = scale_bias
def eval_loss(w):
w = w.reshape((self.out_features, -1))
if self.weight.grad is not None:
self.weight.grad = None
if self.bias is None:
self.weight.data.copy_(torch.from_numpy(w))
else:
if self.bias.grad is not None:
self.bias.grad = None
self.weight.data.copy_(torch.from_numpy(w[:, :-1]))
self.bias.data.copy_(torch.from_numpy(w[:, -1]))
y_pred = self(x).squeeze_(-1)
loss = criterion(y_pred, y)
loss.backward()
if self.alpha != 0.0:
if self.penalty == "l2":
penalty = 0.5 * self.alpha * torch.norm(self.weight)**2
elif self.penalty == "l1":
penalty = self.alpha * torch.norm(self.weight, p=1)
penalty.backward()
loss = loss + penalty
return loss.item()
        def eval_grad(w):
            # the regularization gradient is already accumulated into weight.grad
            # by penalty.backward() in eval_loss, so it is not added again here
            dw = self.weight.grad.data
if self.bias is not None:
db = self.bias.grad.data
dw = torch.cat((dw, db.view(-1, 1)), dim=1)
return dw.cpu().numpy().ravel().astype("float64")
w_init = self.weight.data
if self.bias is not None:
w_init = torch.cat((w_init, self.bias.data.view(-1, 1)), dim=1)
w_init = w_init.cpu().numpy().astype("float64")
w = optimize.fmin_l_bfgs_b(
eval_loss, w_init, fprime=eval_grad, maxiter=100, disp=0)
if isinstance(w, tuple):
w = w[0]
w = w.reshape((self.out_features, -1))
self.weight.grad.data.zero_()
if self.bias is None:
self.weight.data.copy_(torch.from_numpy(w))
else:
self.bias.grad.data.zero_()
self.weight.data.copy_(torch.from_numpy(w[:, :-1]))
self.bias.data.copy_(torch.from_numpy(w[:, -1]))
criterion.reduction = reduction
if self.bias is not None:
self.bias.data.mul_(self.scale_bias)
self.scale_bias = None
def decision_function(self, x):
x = torch.from_numpy(x)
if self.reverse_complement:
n_samples, n_features = x.shape
n_features = n_features // 2
x = torch.cat([x[:, :n_features], x[:, n_features:]])
return self(x).data.numpy().ravel()
def predict(self, x):
return self.decision_function(x)
def predict_proba(self, x):
return self._predict_proba_lr(x)
@property
def coef_(self):
return self.weight.data.numpy()
@property
def intercept_(self):
return self.bias.data.numpy()
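# Illustrative usage sketch (commented out; the alphabet size, filter size and
# hyper-parameters below are arbitrary examples, not values prescribed by the
# package):
# embed = BioEmbedding(num_embeddings=4, mask_zeros=True)   # e.g. DNA alphabet
# layer = CKNLayer(4, 16, filter_size=10, subsampling=2, kernel_args=0.3)
# x = torch.randint(1, 5, (8, 100))        # batch of index-encoded sequences
# features = layer(embed(x))               # batch_size x 16 x pooled_length
# pooled = GlobalAvg1D()(features)         # batch_size x 16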