|
a |
|
b/utils.py |
|
|
1 |
import requests |
|
|
2 |
import zipfile |
|
|
3 |
|
|
|
4 |
import numpy as np |
|
|
5 |
import matplotlib.pyplot as plt |
|
|
6 |
|
|
|
7 |
def get_data(): |
|
|
8 |
print('Downloading started') |
|
|
9 |
url = 'http://bbci.de/competition/download/competition_iv/BCICIV_1calib_1000Hz_mat.zip' |
|
|
10 |
|
|
|
11 |
username = 'replace_with_your_own_username' |
|
|
12 |
password = 'replace_with_your_own_password' |
|
|
13 |
req = requests.get(url, auth=(username,password)) |
|
|
14 |
filename = url.split('/')[-1] |
|
|
15 |
|
|
|
16 |
with open(filename,'wb') as output_file: |
|
|
17 |
output_file.write(req.content) |
|
|
18 |
print('Downloading Completed') |
|
|
19 |
|
|
|
20 |
# Change to your path |
|
|
21 |
print('Unzipping') |
|
|
22 |
path_to_zip_file = 'BCICIV_1calib_1000Hz_mat.zip' |
|
|
23 |
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref: |
|
|
24 |
zip_ref.extractall('data/') |
|
|
25 |
print('Unzipping Completed') |
|
|
26 |
return |
|
|
27 |
|
|
|
28 |
def parse_log(fp): |
|
|
29 |
log = { |
|
|
30 |
'log_info': fp, |
|
|
31 |
'train_loss': [], |
|
|
32 |
'val_loss': [], |
|
|
33 |
'train_acc':[], |
|
|
34 |
'val_acc': [] |
|
|
35 |
} |
|
|
36 |
f = open(fp, 'r') |
|
|
37 |
lines = f.readlines() |
|
|
38 |
for line in lines: |
|
|
39 |
if 'train Loss' in line: |
|
|
40 |
line = line.split() |
|
|
41 |
log['train_loss'].append(float(line[2])) |
|
|
42 |
log['train_acc'].append(float(line[4])) |
|
|
43 |
elif 'val Loss' in line: |
|
|
44 |
line = line.split() |
|
|
45 |
log['val_loss'].append(float(line[2])) |
|
|
46 |
log['val_acc'].append(float(line[4])) |
|
|
47 |
return log |
|
|
48 |
|
|
|
49 |
def plot_log(fp): |
|
|
50 |
log = parse_log(fp) |
|
|
51 |
f, ax = plt.subplots(1, 2, figsize=(15,5)) |
|
|
52 |
ax[0].plot(log['train_acc'], label='Training Accuracy', linestyle='dashed') |
|
|
53 |
ax[0].plot(log['val_acc'], label='Validation Accuracy', linestyle='dashed') |
|
|
54 |
ax[0].legend() |
|
|
55 |
ax[1].plot(log['train_loss'], label='Training Loss', linestyle='dashed') |
|
|
56 |
ax[1].plot(log['val_loss'], label='Validation Loss', linestyle='dashed') |
|
|
57 |
ax[1].legend() |
|
|
58 |
f.suptitle(f'Experiment Log: {fp}') |