|
a |
|
b/ipynb/KlsAutoencoder.py |
|
|
1 |
|
|
|
2 |
from fastai import * |
|
|
3 |
from fastai.basic_data import * |
|
|
4 |
from fastai.basic_train import * |
|
|
5 |
from fastai.tabular import * |
|
|
6 |
from torch import nn |
|
|
7 |
import torch.nn.functional as F |
|
|
8 |
from torch.utils.data import Dataset |
|
|
9 |
|
|
|
10 |
class KlsDataset(Dataset):
    """Dataset of k-mer latent representations.

    Wraps a 2-D array (a DataFrame is unwrapped to its ``.values``) and
    yields ``(input, target)`` pairs for (denoising) autoencoder training.
    """

    def __init__(self, data, noise=0.):
        super().__init__()
        # DataFrames are reduced to their underlying ndarray up front.
        self.items = data.values if isinstance(data, pd.DataFrame) else data
        self.noise = noise

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        row = self.items[idx, :]
        if self.noise == 0.:
            corrupted = row
        else:
            # NOTE(review): mix_noise is not defined in this file — presumably
            # supplied by a subclass or added elsewhere; confirm before
            # constructing with noise > 0.
            corrupted = self.mix_noise(row)
        return corrupted, row
|
|
22 |
|
|
|
23 |
def wing(dims):
    """Build one MLP "wing": bias-free Linear layers joined by ReLUs.

    ``dims`` is a sequence of layer widths; each consecutive pair becomes an
    ``nn.Linear(in, out, bias=False)``. A ReLU follows every Linear except
    the last, so the wing's output is unbounded.

    The original interleaved the layers through a numpy object array
    (``np.asarray(list(zip(...))).ravel()[:-1]``), which relies on numpy's
    handling of arrays of arbitrary objects; a plain loop produces the
    identical layer sequence without that fragility.
    """
    layers = []
    for n_in, n_out in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(n_in, n_out, bias=False))
        layers.append(nn.ReLU())
    # Drop the trailing ReLU: no activation after the final Linear.
    return nn.Sequential(*layers[:-1])
|
|
29 |
|
|
|
30 |
def init_weights(m):
    """Xavier-initialize the weight of Linear modules.

    Intended to be passed to ``module.apply(init_weights)``; non-Linear
    modules are left untouched.
    """
    # isinstance instead of an exact type comparison, so subclasses of
    # nn.Linear are initialized as well.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def print_weights(nlayer, model=None):
    """Debug helper: print the weight tensor of the ``nlayer``-th module of
    ``model.encoder`` (as enumerated by ``Module.modules()``).

    ``model`` defaults to the module-level global ``net`` for backward
    compatibility with the original notebook usage; pass a model explicitly
    to avoid relying on that global.
    """
    target = net if model is None else model
    print(list(target.encoder.modules())[nlayer].weight)
|
|
37 |
|
|
|
38 |
class KlsAutoEncoder (nn.Module):
    """Generic autoencoder built from two ``wing`` MLPs.

    ``encoder_dims`` / ``decoder_dims`` are width sequences passed to
    ``wing``; callers typically make ``decoder_dims`` the reverse of
    ``encoder_dims``.
    """

    def __init__(self, encoder_dims, decoder_dims):
        # BUG FIX: was super().__init__(self), which passes self twice and
        # raises TypeError under nn.Module.
        super().__init__()
        self.encoder = wing(encoder_dims)
        self.decoder = wing(decoder_dims)

    def forward(self, x):
        """Encode then decode ``x``; returns the reconstruction."""
        x = self.encoder(x)
        return self.decoder(x)

    def save_encoder(self, file: "PathOrStr"):
        """Persist only the encoder's weights (state_dict) to ``file``.

        BUG FIX: the original saved to an undefined name ``path``; it now
        saves to the ``file`` argument as intended.
        """
        torch.save(self.encoder.state_dict(), file)
|
|
51 |
|
|
|
52 |
class Encoder():
    """Encoder part of ``KlsAutoEncoder``, loaded and ready for inference."""

    def __init__(self, file: "PathOrStr", dims: "Collection" = (100, 50, 3)):
        """Rebuild the encoder wing and load its saved weights.

        file: checkpoint produced by ``KlsAutoEncoder.save_encoder``.
        dims: layer widths; must match those used at training time.
              (Default changed from a list to an equivalent tuple to avoid
              a shared mutable default argument.)
        """
        e = wing(dims).double()
        e.load_state_dict(torch.load(file))
        e.eval()  # inference mode: freezes dropout/batchnorm behavior
        self.e = e

    def transform(self, data: "Collection"):
        """Transform ``data`` to its latent representation (numpy array)."""
        # Call the module directly (not .forward) so registered hooks run.
        return self.e(tensor(data).double()).cpu().detach().numpy()