a b/ipynb/KlsAutoencoder.py
1
2
from fastai import *
3
from fastai.basic_data import *
4
from fastai.basic_train import *
5
from fastai.tabular import *
6
from torch import nn
7
import torch.nn.functional as F
8
from torch.utils.data import Dataset
9
10
class KlsDataset(Dataset):
    """Kmer latent representation dataset.

    Wraps a 2-D table (a ``pd.DataFrame`` or an array indexable as
    ``data[idx, :]``) and yields ``(input, target)`` pairs in which the
    target is always the unmodified row.  When ``noise`` is non-zero the
    input is a noised copy of the row, otherwise input and target are the
    same row.
    """

    def __init__(self, data, noise=0.):
        """Store the rows and the noise level.

        Args:
            data: ``pd.DataFrame`` (its ``.values`` are used) or any object
                supporting ``len()`` and ``data[idx, :]`` indexing.
            noise: noise level handed to :meth:`mix_noise`; ``0.`` disables
                noising entirely.
        """
        super().__init__()
        self.items = data.values if isinstance(data, pd.DataFrame) else data
        self.noise = noise

    def __len__(self):
        """Return the number of rows."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(input, target)`` for row ``idx``; target is the clean row."""
        item = self.items[idx, :]
        if self.noise == 0.:
            return (item, item)
        return (self.mix_noise(item), item)

    def mix_noise(self, item):
        """Return a copy of ``item`` with zero-mean Gaussian noise added.

        ``self.noise`` is used as the standard deviation.  The original code
        called this method without defining it, so any non-zero ``noise``
        raised ``AttributeError``.

        NOTE(review): additive Gaussian noise is an assumed interpretation
        of "mix noise" -- confirm against the intended corruption scheme.
        """
        noisy = item + np.random.normal(0., self.noise, size=np.shape(item))
        # Keep the input dtype aligned with the target so both halves of the
        # pair can be fed to the same model without an extra cast.
        if isinstance(item, np.ndarray) and np.issubdtype(item.dtype, np.floating):
            return noisy.astype(item.dtype, copy=False)
        return noisy
22
23
def wing(dims):
    """Build one half ("wing") of the autoencoder as an ``nn.Sequential``.

    Consecutive entries of ``dims`` become bias-free ``nn.Linear`` layers,
    with an ``nn.ReLU`` between each pair of linear layers.  The final
    linear layer is not followed by an activation.

    Args:
        dims: sequence of layer widths, e.g. ``[100, 50, 3]`` yields
            ``Linear(100, 50) -> ReLU -> Linear(50, 3)``.  Fewer than two
            entries yields an empty ``nn.Sequential``.

    Returns:
        The assembled ``nn.Sequential``.
    """
    layers = []
    for n_in, n_out in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(n_in, n_out, bias=False))
        layers.append(nn.ReLU())
    # Drop the trailing ReLU so the last layer's output stays linear.
    return nn.Sequential(*layers[:-1])
29
30
def init_weights(m):
    """Apply Xavier-uniform initialisation to the weight of an ``nn.Linear``.

    Only modules whose type is exactly ``nn.Linear`` are touched; every
    other module (including subclasses of ``nn.Linear``) is left unchanged.
    The signature takes a single module, which is the shape
    ``nn.Module.apply`` expects of its callback.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.xavier_uniform_(m.weight)
33
34
35
def print_weights(nlayer, model=None):
    """Print the weight of the ``nlayer``-th module of a model's encoder.

    Debugging helper.  Modules are indexed in the order produced by
    ``model.encoder.modules()``, so index 0 is the encoder container itself.

    Args:
        nlayer: index into ``list(model.encoder.modules())``; the selected
            module must have a ``weight`` attribute.
        model: object exposing an ``encoder`` module.  When omitted, the
            module-level global ``net`` is used, as the original
            implementation did; that global must then exist.
    """
    if model is None:
        # Legacy behaviour: fall back to a global ``net`` defined by the caller.
        model = net  # noqa: F821
    print(list(model.encoder.modules())[nlayer].weight)
37
38
class KlsAutoEncoder(nn.Module):
    """Generic autoencoder made of an encoder wing and a decoder wing."""

    def __init__(self, encoder_dims, decoder_dims):
        """Build both halves with ``wing``.

        Args:
            encoder_dims: layer widths passed to ``wing`` for the encoder.
            decoder_dims: layer widths passed to ``wing`` for the decoder.
        """
        # nn.Module.__init__ takes no positional arguments; the original
        # ``super().__init__(self)`` raised TypeError on construction.
        super().__init__()
        self.encoder = wing(encoder_dims)
        self.decoder = wing(decoder_dims)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction by the decoder."""
        latent = self.encoder(x)
        return self.decoder(latent)

    def save_encoder(self,file:PathOrStr):
        """Save the encoder's ``state_dict`` to ``file`` via ``torch.save``."""
        # The original referenced an undefined name ``path`` here.
        torch.save(self.encoder.state_dict(), file)
51
52
class Encoder():
    """Encoder part of KlsAutoEncoder, ready for inference."""

    def __init__(self,file:PathOrStr,dims:Collection=(100, 50, 3)):
        """Rebuild the encoder and load its saved weights.

        Args:
            file: path to a ``state_dict`` readable by ``torch.load``.
            dims: layer widths passed to ``wing``; they must match the
                architecture the weights were saved from.
        """
        e = wing(dims).double()
        # NOTE: torch.load unpickles its input -- only load trusted files.
        # map_location lets weights saved from a GPU load on a CPU-only host;
        # the module built above already lives on the CPU.
        e.load_state_dict(torch.load(file, map_location="cpu"))
        e.eval()
        self.e = e

    def transform(self,data:Collection):
        """Transform ``data`` to its latent representation as a numpy array."""
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            latent = self.e(tensor(data).double())
        return latent.cpu().detach().numpy()