# model.py

import torch
import torch.nn as nn

from timm.models.layers import PatchEmbed, Mlp, DropPath
from timm.models.registry import register_model

class Attention(nn.Module):
    """Standard ViT-style multi-head self-attention."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3 * C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each with a residual connection."""

    def __init__(self, dim, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth; we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class EEGTransformer(nn.Module):
    """Transformer classifier over per-channel EEG feature tokens with a learnable class token."""

    def __init__(self):
        super().__init__()
        self.dim = 4000           # feature length of each channel token
        self.decoder_depth = 1    # number of transformer blocks

        self.blocks = nn.Sequential(*[
            Block(
                dim=self.dim,
                drop=0.2,
                attn_drop=0.2,
            )
            for _ in range(self.decoder_depth)])
        self.relu = nn.ReLU()

        # Learnable class token, shape (1, 1, dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.dim))

        # Classification head: dim -> 512 -> 3 classes
        self.fc1 = nn.Linear(self.dim, 512)
        # self.fc1 = nn.Linear(60, 512)  # would be needed for the alternative attn_map head in forward()
        self.fc2 = nn.Linear(512, 3)

        torch.nn.init.normal_(self.cls_token, std=.02)

    def forward(self, x):
        # x: (bz, num_channels, dim), e.g. (bz, 59, 4000); one token per EEG channel
        cls_token = self.cls_token.repeat(x.shape[0], 1, 1)  # (bz, 1, dim)
        x = torch.cat((x, cls_token), 1)                     # (bz, num_channels + 1, dim)
        x = self.blocks(x)

        # Alternative head (unused): score every token against the class token and
        # classify the resulting map; requires self.fc1 = nn.Linear(num_channels + 1, 512).
        # cls_token = x[:, -1, :].unsqueeze(-1)              # (bz, dim, 1)
        # attn_map = torch.bmm(x, cls_token).squeeze(-1)     # (bz, num_channels + 1)
        # attn_map = self.relu(self.fc1(attn_map))
        # attn_map = self.fc2(attn_map)
        # return attn_map

        # Classify from the transformed class token.
        cls_token = x[:, -1, :]                              # (bz, dim)
        cls_token = self.relu(self.fc1(cls_token))
        cls_token = self.fc2(cls_token)
        return cls_token

@register_model
def eegt(**kwargs):
    # timm registry entry point; extra keyword arguments (e.g. pretrained) are accepted and ignored.
    model = EEGTransformer()
    return model
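

# Quick usage sketch. A minimal smoke test assuming the input layout implied by the
# shape comments above: a batch of EEG samples with 59 channel tokens of self.dim
# (= 4000) features each. The batch size of 2 and the random tensor are placeholders.
if __name__ == "__main__":
    model = EEGTransformer().eval()
    dummy = torch.randn(2, 59, 4000)  # (bz, num_channels, dim)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 3])
    # The model is also registered with timm, so timm.create_model("eegt")
    # should resolve to the same architecture.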