# -*- coding: utf-8 -*-
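"""Sequence model combining a convolutional kernel network (CKN) feature
extractor with optimal-transport attention pooling (OTKernel) and a linear
classifier."""
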
import itertools

import torch
from torch import nn

from .layers import OTKernel, Linear
from ..ckn.layers import BioEmbedding
from ..ckn.models import CKNSequential


class SeqAttention(nn.Module):
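    """Sequence classifier: a BioEmbedding input layer, a CKNSequential
    feature extractor, an OTKernel attention pooling and a Linear classifier.
    The feature layers are trained without supervision (`unsup_train`), the
    classifier on the frozen representations (`train_classifier`).
    """
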
def __init__(self, in_channels, nclass, hidden_sizes, filter_sizes,
subsamplings, kernel_args=None, eps=0.1, heads=1,
out_size=1, max_iter=50, alpha=0., fit_bias=True,
mask_zeros=True, position_encoding=None, image=True):
super().__init__()
self.embed_layer = BioEmbedding(
in_channels, False, mask_zeros=True, no_embed=True)
self.ckn_model = CKNSequential(
in_channels, hidden_sizes, filter_sizes,
subsamplings, kernel_args_list=kernel_args)
        self.attention = OTKernel(
            hidden_sizes[-1], out_size, heads=heads, eps=eps,
            max_iter=max_iter, position_encoding=position_encoding,
            image=image)
self.position_encoding = position_encoding
self.out_features = out_size * heads * hidden_sizes[-1]
self.nclass = nclass
self.classifier = Linear(self.out_features, nclass, bias=fit_bias)
self.alpha = alpha
        self.mask_zeros = mask_zeros

    def feature_parameters(self):
        return itertools.chain(self.ckn_model.parameters(),
                               self.attention.parameters())

    def normalize_(self):
        self.ckn_model.normalize_()

    def ckn_representation_at(self, input, n=0):
        """Representation and mask at the input of CKN layer n (used for
        layer-wise unsupervised training)."""
        output = self.embed_layer(input)
        mask = self.embed_layer.compute_mask(input)
        output = self.ckn_model.representation(output, n)
        mask = self.ckn_model.compute_mask(mask, n)
        return output, mask

    def ckn_representation(self, input):
        output = self.embed_layer(input)
        # (batch, hidden, length) -> (batch, length, hidden) for the attention
        output = self.ckn_model(output).permute(0, 2, 1).contiguous()
        return output

    def representation(self, input, coords=None):
        output = self.embed_layer(input)
        mask = self.embed_layer.compute_mask(input)
        # (batch, hidden, length) -> (batch, length, hidden) for the attention
        output = self.ckn_model(output).permute(0, 2, 1).contiguous()
        mask = self.ckn_model.compute_mask(mask)
        if not self.mask_zeros:
            mask = None
        output = self.attention(output, mask, coords=coords).reshape(
            output.shape[0], -1)
        return output

    def forward(self, input):
        output = self.representation(input)
        return self.classifier(output)

    def predict(self, data_loader, only_repr=False, use_cuda=False):
        """Encode (or classify) every sample in data_loader.

        Returns the stacked outputs and the corresponding labels.
        """
        n_samples = len(data_loader.dataset)
        target_output = torch.empty(n_samples, dtype=torch.long)
        batch_start = 0
        for i, batch in enumerate(data_loader):
            # move channels first for the embedding layer
            data = batch['img'].permute(0, 2, 1)
            target = batch['label']
            batch_size = data.shape[0]
            if use_cuda:
                data = data.cuda()
            with torch.no_grad():
                if only_repr:
                    batch_out = self.representation(data).cpu()
                else:
                    batch_out = self(data).cpu()
            if i == 0:
                output = batch_out.new_empty(
                    [n_samples] + list(batch_out.shape[1:]))
            output[batch_start: batch_start + batch_size] = batch_out
            target_output[batch_start: batch_start + batch_size] = target
            batch_start += batch_size
        return output, target_output

    def train_classifier(self, data_loader, criterion=None, epochs=100,
                         optimizer=None, use_cuda=False):
        encoded_train, encoded_target = self.predict(
            data_loader, only_repr=True, use_cuda=use_cuda)
        self.classifier.fit(encoded_train, encoded_target, criterion,
                            reg=self.alpha, epochs=epochs, optimizer=optimizer,
                            use_cuda=use_cuda)

    def unsup_train(self, data_loader, n_sampling_patches=300000,
                    n_samples=5000, wb=False, use_cuda=False):
        """Layer-wise unsupervised training of the CKN, then of the attention."""
        self.eval()
        if use_cuda:
            self.cuda()
        for i, ckn_layer in enumerate(self.ckn_model):
            print("Training ckn layer {}".format(i))
            n_patches = 0
            try:
                n_patches_per_batch = (
                    n_sampling_patches + len(data_loader) - 1
                ) // len(data_loader)
            except (TypeError, ZeroDivisionError):
                # loaders without a usable length: fall back to a fixed
                # number of patches per batch
                n_patches_per_batch = 1000
            patches = torch.empty(n_sampling_patches, ckn_layer.patch_dim)
            if use_cuda:
                patches = patches.cuda()
            # sample patches from the representations entering layer i
            for data, coords, _ in data_loader:
                if n_patches >= n_sampling_patches:
                    break
                if use_cuda:
                    data = data.cuda()
                with torch.no_grad():
                    data, mask = self.ckn_representation_at(data, i)
                    data_patches = ckn_layer.sample_patches(
                        data, mask, n_patches_per_batch)
                size = data_patches.size(0)
                if n_patches + size > n_sampling_patches:
                    size = n_sampling_patches - n_patches
                    data_patches = data_patches[:size]
                patches[n_patches: n_patches + size] = data_patches
                n_patches += size
            print("total number of patches: {}".format(n_patches))
            patches = patches[:n_patches]
            ckn_layer.unsup_train(patches, init=None)

        n_samples = min(n_samples, len(data_loader.dataset))
        cur_samples = 0
        print("Training attention layer")
        # gather CKN representations for up to n_samples sequences
        # (the coords yielded by the loader are currently unused; a disabled
        # position-encoding path would collect them alongside the data)
        for i, (data, coords, _) in enumerate(data_loader):
            if cur_samples >= n_samples:
                break
            if use_cuda:
                data = data.cuda()
            with torch.no_grad():
                data = self.ckn_representation(data)
            if i == 0:
                patches = torch.empty([n_samples] + list(data.shape[1:]))
            size = data.shape[0]
            if cur_samples + size > n_samples:
                size = n_samples - cur_samples
                data = data[:size]
            patches[cur_samples: cur_samples + size] = data
            cur_samples += size
        # guard against loaders that yield fewer than n_samples rows
        patches = patches[:cur_samples]
        print("attention training set shape: {}".format(patches.shape))
        self.attention.unsup_train(patches, wb=wb, use_cuda=use_cuda)
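

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only). The constructor arguments below
# are assumptions chosen for the demo, not values taken from this repository,
# and the random tensor stands in for a batch of encoded sequences.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # toy model: 20-letter alphabet (e.g. amino acids), 2 classes, one CKN
    # layer with 128 filters of size 9 and no subsampling
    model = SeqAttention(
        in_channels=20, nclass=2, hidden_sizes=[128], filter_sizes=[9],
        subsamplings=[1], eps=0.5, heads=1, out_size=3)
    model.normalize_()
    # batch of 4 sequences of length 100, channels first as in `predict`
    x = torch.randn(4, 20, 100)
    scores = model(x)
    print(scores.shape)  # expected: torch.Size([4, 2])
    # In practice the feature layers are first trained without supervision
    # via `unsup_train`, then the classifier via `train_classifier`.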