Diff of /model.py [000000] .. [597177]

b/model.py
import torch
import torch.nn as nn

from timm.models.layers import PatchEmbed, Mlp, DropPath
from timm.models.registry import register_model

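# NOTE: these import paths target older timm releases; newer versions keep
# `timm.models.layers` / `timm.models.registry` as deprecation shims and expose
# the same symbols via `timm.layers` (PatchEmbed, Mlp, DropPath) and
# `timm.models` (register_model).
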
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # project to q, k, v and split heads: (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        # scaled dot-product attention weights: (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # merge heads back to (B, N, C) and apply the output projection
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):

    def __init__(self, dim, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

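# The classifier below treats the input as a sequence of per-channel tokens
# (bz x 59 x self.dim, per the shape comments in forward), appends a learnable
# cls token, encodes the sequence with `decoder_depth` Block(s), and maps the
# final cls position through a small MLP head (self.dim -> 512 -> 3 logits).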
class EEGTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.dim = 4000
        self.decoder_depth = 1
        self.blocks = nn.Sequential(*[
            Block(
                dim=self.dim,
                drop=0.2,
                attn_drop=0.2,
            )
            for _ in range(self.decoder_depth)])
        self.relu = nn.ReLU()
        # self.softmax = nn.Softmax()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.dim))  # 1 x 1 x self.dim

        # self.fc1 = nn.Linear(60, 512)
        self.fc1 = nn.Linear(self.dim, 512)

        self.fc2 = nn.Linear(512, 3)

        torch.nn.init.normal_(self.cls_token, std=.02)

    def forward(self, x):
        # x -> bz x 59 x self.dim
        # print(x.shape)
        cls_token = self.cls_token.repeat(x.shape[0], 1, 1)  # bz x 1 x self.dim
        x = torch.cat((x, cls_token), 1)  # bz x 60 x self.dim
        x = self.blocks(x)

        # cls_token = x[:, -1, :].squeeze(1).unsqueeze(-1)  # bz x self.dim x 1
        # attn_map = torch.bmm(x, cls_token).squeeze(-1)  # bz x 60
        # attn_map = self.relu(self.fc1(attn_map))
        # attn_map = self.fc2(attn_map)
        # return attn_map

        # read out the appended cls token and classify it into 3 classes
        cls_token = x[:, -1, :]  # bz x self.dim
        cls_token = self.relu(self.fc1(cls_token))
        cls_token = self.fc2(cls_token)
        return cls_token

@register_model
def eegt(**kwargs):
    # kwargs (e.g. `pretrained` passed by timm.create_model) are accepted but unused
    model = EEGTransformer()
    return model
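
# Minimal smoke test (a sketch: the batch size of 2 and the 59-channel input
# are assumptions taken from the shape comments in EEGTransformer.forward).
if __name__ == "__main__":
    model = EEGTransformer()           # or timm.create_model('eegt') after `import timm`
    dummy = torch.randn(2, 59, 4000)   # bz x channels x self.dim
    logits = model(dummy)
    print(logits.shape)                # expected: torch.Size([2, 3])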