|
a |
|
b/ecg_classification/config.py |
|
|
1 |
import random |
|
|
2 |
import numpy as np |
|
|
3 |
import torch |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class Config: |
|
|
7 |
csv_path = '' |
|
|
8 |
seed = 2021 |
|
|
9 |
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' |
|
|
10 |
attn_state_path = '../input/mitbih-with-synthetic/attn.pth' |
|
|
11 |
lstm_state_path = '../input/mitbih-with-synthetic/lstm.pth' |
|
|
12 |
cnn_state_path = '../input/mitbih-with-synthetic/cnn.pth' |
|
|
13 |
|
|
|
14 |
attn_logs = '../input/mitbih-with-synthetic/attn.csv' |
|
|
15 |
lstm_logs = '../input/mitbih-with-synthetic/lstm.csv' |
|
|
16 |
cnn_logs = '../input/mitbih-with-synthetic/cnn.csv' |
|
|
17 |
|
|
|
18 |
train_csv_path = '../input/mitbih-with-synthetic/mitbih_with_syntetic_train.csv' |
|
|
19 |
test_csv_path = '../input/mitbih-with-synthetic/mitbih_with_syntetic_test.csv' |
|
|
20 |
|
|
|
21 |
def seed_everything(seed: int): |
|
|
22 |
random.seed(seed) |
|
|
23 |
np.random.seed(seed) |
|
|
24 |
torch.manual_seed(seed) |
|
|
25 |
if torch.cuda.is_available(): |
|
|
26 |
torch.cuda.manual_seed(seed) |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
if __name__ == '__main__': |
|
|
30 |
config = Config() |
|
|
31 |
seed_everything(config.seed) |