# app/models/backbones/tcn.py

import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm


# From TCN original paper https://github.com/locuslab/TCN
class Chomp1d(nn.Module):
    """Trims the trailing `chomp_size` steps so the padded convolution stays causal."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # Drop the right-side padding so each output depends only on past inputs.
        return x[:, :, : -self.chomp_size].contiguous()
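

# A minimal sanity-check sketch (hypothetical helper, not part of the original
# module): with padding = (kernel_size - 1) * dilation, a Conv1d followed by
# Chomp1d keeps the time dimension unchanged while staying causal.
def _demo_chomp_restores_length():
    conv = nn.Conv1d(3, 8, kernel_size=2, padding=1, dilation=1)
    x = torch.randn(4, 3, 10)  # (batch, channels, time)
    y = Chomp1d(1)(conv(x))  # conv pads time to 11; chomp trims back to 10
    assert y.shape == (4, 8, 10)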


class TemporalBlock(nn.Module):
    def __init__(
        self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2
    ):
        super(TemporalBlock, self).__init__()
        # Two causal conv layers: pad on both sides, then chomp the right side.
        self.conv1 = weight_norm(
            nn.Conv1d(
                n_inputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            dim=None,
        )
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(
            nn.Conv1d(
                n_outputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            dim=None,
        )
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1,
            self.chomp1,
            self.relu1,
            self.dropout1,
            self.conv2,
            self.chomp2,
            self.relu2,
            self.dropout2,
        )
        # 1x1 conv matches channel counts for the residual connection.
        self.downsample = (
            nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        )
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)
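

# Shape sketch (hypothetical values): a TemporalBlock preserves the time
# dimension regardless of dilation, so blocks can be stacked freely.
def _demo_temporal_block_shapes():
    block = TemporalBlock(
        n_inputs=3, n_outputs=8, kernel_size=2, stride=1, dilation=4, padding=4
    )
    x = torch.randn(2, 3, 50)  # (batch, channels, time)
    assert block(x).shape == (2, 8, 50)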


# From TCN original paper https://github.com/locuslab/TCN
class TemporalConvNet(nn.Module):
    def __init__(
        self,
        num_inputs,
        num_channels,  # hidden width (int) or explicit per-level channel list
        max_seq_length=0,
        kernel_size=2,
        dropout=0.0,
    ):
        super(TemporalConvNet, self).__init__()

        # Automatically compute the depth from the desired sequence length so
        # that the dilated stack's receptive field covers max_seq_length.
        if isinstance(num_channels, int) and max_seq_length:
            num_channels = [num_channels] * int(
                np.ceil(np.log(max_seq_length / 2) / np.log(kernel_size))
            )
        elif isinstance(num_channels, int) and not max_seq_length:
            raise ValueError(
                "a maximum sequence length needs to be provided if num_channels is an int"
            )
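
        # Worked example (assumed values): max_seq_length=64 and kernel_size=2
        # give ceil(log2(64 / 2)) = 5 levels; with dilations 1, 2, 4, 8, 16 the
        # stack's receptive field is 1 + 2 * (kernel_size - 1) * (2**5 - 1) = 63,
        # roughly covering the 64-step sequence.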
        # Output feature dimension used by forward(); with a per-level list
        # this must be the channel count of the last level.
        self.num_channels = num_channels[-1]

        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers += [
                TemporalBlock(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation_size,
                    padding=(kernel_size - 1) * dilation_size,
                    dropout=dropout,
                )
            ]

        self.network = nn.Sequential(*layers)

    def forward(self, x, device, info=None):
        """Extra `info` is not used here."""
        batch_size, time_steps, _ = x.size()
        out = torch.zeros(batch_size, time_steps, self.num_channels, device=device)
        for cur_time in range(time_steps):
            # Run the causal stack on the prefix x[:, : t + 1] and mean-pool
            # its outputs over time to get the embedding for step t.
            cur_x = x[:, : cur_time + 1, :]
            cur_x = cur_x.permute(0, 2, 1)  # to channel-first (B, C, T)
            o = self.network(cur_x)
            o = o.permute(0, 2, 1)  # back to channel-last (B, T, C)
            out[:, cur_time, :] = torch.mean(o, dim=1)
        return out
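

# The per-prefix loop above costs O(T^2) network passes. Because the chomped
# convolutions are strictly causal, step t of a single full pass depends only
# on inputs up to t, so the same prefix mean-pooling can be done in one pass
# with a cumulative mean. A hypothetical sketch (matches the loop in eval
# mode; under active dropout the two versions differ stochastically):
def tcn_forward_single_pass(model, x, device):
    _, time_steps, _ = x.size()
    o = model.network(x.permute(0, 2, 1).to(device))  # one causal pass, (B, C, T)
    o = o.permute(0, 2, 1)  # back to channel-last (B, T, C)
    csum = torch.cumsum(o, dim=1)  # running sum over time
    counts = torch.arange(
        1, time_steps + 1, device=csum.device, dtype=csum.dtype
    ).view(1, -1, 1)
    return csum / counts  # cumulative mean = prefix mean-pooling per step


# Example usage sketch (assumed shapes):
if __name__ == "__main__":
    device = torch.device("cpu")
    model = TemporalConvNet(num_inputs=10, num_channels=64, max_seq_length=32)
    x = torch.randn(2, 32, 10)  # (batch, time, features)
    print(model(x, device).shape)  # torch.Size([2, 32, 64])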