|
a |
|
b/model/common.py |
|
|
1 |
""" |
|
|
2 |
conv1d与lstm结合, 输入数据为rri. |
|
|
3 |
""" |
|
|
4 |
|
|
|
5 |
import keras |
|
|
6 |
import numpy as np |
|
|
7 |
import time |
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
import json |
|
|
10 |
import os |
|
|
11 |
import random |
|
|
12 |
from sklearn.preprocessing import StandardScaler, MinMaxScaler |
|
|
13 |
|
|
|
14 |
from keras.utils import plot_model |
|
|
15 |
from keras.callbacks import * |
|
|
16 |
from sklearn.model_selection import train_test_split |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def write_txt_file(list_info, write_file_path): |
|
|
20 |
""" |
|
|
21 |
Write list object to TXT file. |
|
|
22 |
:param list list_info: List object you want to write. |
|
|
23 |
:param string write_file_path: TXT file path. |
|
|
24 |
:return: None |
|
|
25 |
""" |
|
|
26 |
with open(write_file_path, "w") as f: |
|
|
27 |
for info in list_info: |
|
|
28 |
f.write(str(info) + "\n") |
|
|
29 |
|
|
|
30 |
|
|
|
31 |
def plot_fig(data, file_path="", title="", show_fig=False): |
|
|
32 |
if not os.path.exists(file_path) and file_path != "": |
|
|
33 |
os.makedirs(file_path) |
|
|
34 |
min, max = np.min(data, axis=0), np.max(data, axis=0) |
|
|
35 |
x = list(range(0, len(data), 1)) |
|
|
36 |
f, ax = plt.subplots() |
|
|
37 |
ax.plot(x, data) |
|
|
38 |
ax.set_ylim([np.floor(min), np.ceil(max)]) |
|
|
39 |
ax.set_title(title) |
|
|
40 |
if show_fig: |
|
|
41 |
plt.show() |
|
|
42 |
if title != "": |
|
|
43 |
plt.savefig(file_path + title + ".jpg") |
|
|
44 |
plt.close() |
|
|
45 |
|
|
|
46 |
|
|
|
47 |
class LossHistory(keras.callbacks.Callback): |
|
|
48 |
def init(self): |
|
|
49 |
self.losses = [] |
|
|
50 |
|
|
|
51 |
def on_epoch_end(self, batch, logs={}): |
|
|
52 |
self.losses.append(logs.get('loss')) |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
class TrainingMonitor(BaseLogger): |
|
|
56 |
""" |
|
|
57 |
https://blog.csdn.net/OliverkingLi/article/details/81214947 |
|
|
58 |
""" |
|
|
59 |
|
|
|
60 |
def __init__(self, fig_path, model, |
|
|
61 |
train_loss_path, test_loss_path, train_acc_path, test_acc_path, json_path=None, start_At=0): |
|
|
62 |
""" |
|
|
63 |
训练监控初始化 |
|
|
64 |
:param fig_path: loss store path |
|
|
65 |
:param model: |
|
|
66 |
:param json_path: Json file path |
|
|
67 |
:param int start_At: |
|
|
68 |
:return: None |
|
|
69 |
""" |
|
|
70 |
|
|
|
71 |
super(TrainingMonitor, self).__init__() |
|
|
72 |
self.fig_path = fig_path + "/xxx.png" |
|
|
73 |
self.json_path = json_path |
|
|
74 |
self.start_At = start_At |
|
|
75 |
self.model = model |
|
|
76 |
self.epochs = 0 |
|
|
77 |
|
|
|
78 |
self.train_loss_path = train_loss_path |
|
|
79 |
self.test_loss_path = test_loss_path |
|
|
80 |
self.train_acc_path = train_acc_path |
|
|
81 |
self.test_acc_path = test_acc_path |
|
|
82 |
|
|
|
83 |
def on_train_begin(self, logs={}): |
|
|
84 |
self.H = {} |
|
|
85 |
if self.json_path is not None: |
|
|
86 |
if os.path.exists(self.json_path): |
|
|
87 |
self.H = json.loads(open(self.json_path).read()) |
|
|
88 |
if self.start_At > 0: |
|
|
89 |
for k in self.H.keys(): |
|
|
90 |
self.H[k] = self.H[k][:self.start_At] |
|
|
91 |
|
|
|
92 |
def on_epoch_end(self, epoch, logs=None): |
|
|
93 |
for (k, v) in logs.items(): |
|
|
94 |
l = self.H.get(k, []) |
|
|
95 |
l.append(v) |
|
|
96 |
self.H[k] = l |
|
|
97 |
if self.json_path is not None: |
|
|
98 |
f = open(self.json_path, 'w') |
|
|
99 |
f.write(json.dumps(self.H)) |
|
|
100 |
f.close() |
|
|
101 |
if len(self.H["loss"]) > 1: |
|
|
102 |
N = np.arange(0, len(self.H["loss"])) |
|
|
103 |
plt.style.use("ggplot") |
|
|
104 |
plt.figure() |
|
|
105 |
plt.plot(N, self.H["loss"], label="train_loss") |
|
|
106 |
write_txt_file(self.H["loss"], self.train_loss_path) |
|
|
107 |
plt.plot(N, self.H["val_loss"], label="val_loss") |
|
|
108 |
write_txt_file(self.H["val_loss"], self.test_loss_path) |
|
|
109 |
plt.plot(N, self.H["acc"], label="train_acc") |
|
|
110 |
write_txt_file(self.H["acc"], self.train_acc_path) |
|
|
111 |
plt.plot(N, self.H["val_acc"], label="val_acc") |
|
|
112 |
write_txt_file(self.H["val_acc"], self.test_acc_path) |
|
|
113 |
plt.title("Training Loss and Accuracy [Epoch {}]".format(len(self.H["loss"]))) |
|
|
114 |
plt.xlabel("Epoch #") |
|
|
115 |
plt.ylabel("Loss/Accuracy") |
|
|
116 |
plt.legend() |
|
|
117 |
plt.savefig(self.fig_path) |
|
|
118 |
plt.close() |