Diff of /model/common.py [000000] .. [409112]

Switch to unified view

a b/model/common.py
1
"""
2
    conv1d与lstm结合, 输入数据为rri.
3
"""
4
5
import keras
6
import numpy as np
7
import time
8
import matplotlib.pyplot as plt
9
import json
10
import os
11
import random
12
from sklearn.preprocessing import StandardScaler, MinMaxScaler
13
14
from keras.utils import plot_model
15
from keras.callbacks import *
16
from sklearn.model_selection import train_test_split
17
18
19
def write_txt_file(list_info, write_file_path):
20
    """
21
    Write list object to TXT file.
22
    :param list list_info: List object you want to write.
23
    :param string write_file_path: TXT file path.
24
    :return: None
25
    """
26
    with open(write_file_path, "w") as f:
27
        for info in list_info:
28
            f.write(str(info) + "\n")
29
30
31
def plot_fig(data, file_path="", title="", show_fig=False):
32
    if not os.path.exists(file_path) and file_path != "":
33
        os.makedirs(file_path)
34
    min, max = np.min(data, axis=0), np.max(data, axis=0)
35
    x = list(range(0, len(data), 1))
36
    f, ax = plt.subplots()
37
    ax.plot(x, data)
38
    ax.set_ylim([np.floor(min), np.ceil(max)])
39
    ax.set_title(title)
40
    if show_fig:
41
        plt.show()
42
    if title != "":
43
        plt.savefig(file_path + title + ".jpg")
44
    plt.close()
45
    
46
47
class LossHistory(keras.callbacks.Callback):
48
    def init(self):
49
        self.losses = []
50
    
51
    def on_epoch_end(self, batch, logs={}):
52
        self.losses.append(logs.get('loss'))
53
54
55
class TrainingMonitor(BaseLogger):
56
    """
57
    https://blog.csdn.net/OliverkingLi/article/details/81214947
58
    """
59
    
60
    def __init__(self, fig_path, model,
61
                 train_loss_path, test_loss_path, train_acc_path, test_acc_path, json_path=None, start_At=0):
62
        """
63
        训练监控初始化
64
        :param fig_path: loss store path
65
        :param model:
66
        :param json_path: Json file path
67
        :param int start_At:
68
        :return: None
69
        """
70
        
71
        super(TrainingMonitor, self).__init__()
72
        self.fig_path = fig_path + "/xxx.png"
73
        self.json_path = json_path
74
        self.start_At = start_At
75
        self.model = model
76
        self.epochs = 0
77
        
78
        self.train_loss_path = train_loss_path
79
        self.test_loss_path = test_loss_path
80
        self.train_acc_path = train_acc_path
81
        self.test_acc_path = test_acc_path
82
    
83
    def on_train_begin(self, logs={}):
84
        self.H = {}
85
        if self.json_path is not None:
86
            if os.path.exists(self.json_path):
87
                self.H = json.loads(open(self.json_path).read())
88
                if self.start_At > 0:
89
                    for k in self.H.keys():
90
                        self.H[k] = self.H[k][:self.start_At]
91
    
92
    def on_epoch_end(self, epoch, logs=None):
93
        for (k, v) in logs.items():
94
            l = self.H.get(k, [])
95
            l.append(v)
96
            self.H[k] = l
97
        if self.json_path is not None:
98
            f = open(self.json_path, 'w')
99
            f.write(json.dumps(self.H))
100
            f.close()
101
        if len(self.H["loss"]) > 1:
102
            N = np.arange(0, len(self.H["loss"]))
103
            plt.style.use("ggplot")
104
            plt.figure()
105
            plt.plot(N, self.H["loss"], label="train_loss")
106
            write_txt_file(self.H["loss"], self.train_loss_path)
107
            plt.plot(N, self.H["val_loss"], label="val_loss")
108
            write_txt_file(self.H["val_loss"], self.test_loss_path)
109
            plt.plot(N, self.H["acc"], label="train_acc")
110
            write_txt_file(self.H["acc"], self.train_acc_path)
111
            plt.plot(N, self.H["val_acc"], label="val_acc")
112
            write_txt_file(self.H["val_acc"], self.test_acc_path)
113
            plt.title("Training Loss and Accuracy [Epoch {}]".format(len(self.H["loss"])))
114
            plt.xlabel("Epoch #")
115
            plt.ylabel("Loss/Accuracy")
116
            plt.legend()
117
            plt.savefig(self.fig_path)
118
            plt.close()