a b/utils.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Tue Aug 31 17:02:33 2021
4
5
@author: m.beuque
6
"""
7
from sklearn.metrics import confusion_matrix
8
import seaborn as sns
9
from matplotlib import pyplot as plt
10
import numpy as np
11
import os
12
import pandas as pd
13
from xml.dom import minidom
14
import xml.etree.ElementTree as ET
15
16
17
18
def extract_transform_matrix(path_mis):
    """Read the ``TeachPoint`` entries of an XML file and split them into point sets.

    The text of every ``TeachPoint`` element is split on ``;`` and the first two
    parts are each split on ``,`` (e.g. ``"1,2;3,4"`` -> ``[["1", "2"], ["3", "4"]]``).
    The first three teach points form the first pair of returned arrays and the
    remaining teach points form the second pair.

    Parameters
    ----------
    path_mis : str or file-like
        Anything accepted by ``xml.dom.minidom.parse``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(fixed_points_optical, moving_points_motor, moving_points_optical,
        fixed_points_histo)``; each array holds coordinate *strings*, since no
        numeric conversion is performed. The optical/motor/histology meaning
        is taken from the names only (TODO confirm against the acquisition
        software). Presumably every teach point has the same number of
        coordinates so that the stacked array is rectangular.
    """
    document = minidom.parse(path_mis)
    teach_points = document.getElementsByTagName('TeachPoint')

    rows = []
    for node in teach_points:
        halves = str(node.firstChild.data).split(";")
        # Only the first two ';'-separated parts are used.
        rows.append([halves[0].split(","), halves[1].split(",")])

    table = np.array(rows)
    # First three teach points -> first pair; the rest -> second pair.
    return table[:3, 0], table[:3, 1], table[3:, 0], table[3:, 1]
40
41
def process_spotlist(path_spotlist):
    """Parse a spot-list text file into an integer array.

    Every line that contains no ``#`` is split on whitespace. The third field
    is a spot name from which characters ``[4:7]`` and the last three
    characters are taken (presumably the X and Y grid indices of a name such
    as ``R00X012Y034`` -- TODO confirm). Those two substrings replace the third
    field, so each row becomes
    ``field0, field1, name[4:7], name[-3:], field3, ...``.

    Parameters
    ----------
    path_spotlist : str or path-like
        Path of the spot-list file.

    Returns
    -------
    numpy.ndarray
        Integer array with one row per data line.
    """
    rows = []
    # ``with`` guarantees the file is closed even if a line fails to parse.
    with open(path_spotlist) as handle:
        for line in handle:
            # Any line containing '#' is treated as a comment/header.
            if '#' in line:
                continue
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no spot.
                continue
            spot_name = fields[2]
            # Replace by index: list.remove() would delete the first *equal*
            # value, which is not necessarily the element at index 2.
            fields[2:3] = [spot_name[4:7], spot_name[-3:]]
            rows.append(fields)
    return np.array(rows, dtype=int)
56
57
def polygon_extraction(path_mis):
    """Collect the vertices of every ``ROI`` element in an XML file.

    Each ``ROI`` element (searched recursively from the root) must carry a
    ``Name`` attribute. Each of its child elements holds one ``"x,y"`` vertex
    whose comma-separated parts are converted to ``int``.

    Parameters
    ----------
    path_mis : str or file-like
        Anything accepted by ``xml.etree.ElementTree.parse``.

    Returns
    -------
    dict
        Maps ROI name -> list of integer coordinate tuples, in document order.
        If two ROIs share a name, the later one wins.
    """
    root = ET.parse(path_mis).getroot()

    dict_roi = {}
    for roi in root.iter('ROI'):
        # Iterate the element directly: Element.getchildren() was removed in
        # Python 3.9 and raised AttributeError there.
        dict_roi[roi.get('Name')] = [
            tuple(int(coord) for coord in point.text.split(','))
            for point in roi
        ]
    return dict_roi
73
74
def extract_center_pos(path_mis):
    """Return the ``CenterPos`` attribute of the first ``View`` element as integers.

    Parameters
    ----------
    path_mis : str or file-like
        Anything accepted by ``xml.dom.minidom.parse``.

    Returns
    -------
    numpy.ndarray
        The comma-separated ``CenterPos`` values converted to ``int``.

    Raises
    ------
    IndexError
        If the document contains no ``View`` element.
    KeyError
        If the first ``View`` element has no ``CenterPos`` attribute.
    """
    first_view = minidom.parse(path_mis).getElementsByTagName('View')[0]
    raw_position = first_view.attributes['CenterPos'].value
    return np.array(raw_position.split(","), dtype=int)
81
82
def _select_label(unified_label, data_type):
    """Map one ``unified_label`` to its training label, or ``None`` to skip it."""
    if data_type == "stroma":
        # Only the last character decides: '...g' -> gland, '...t' -> tissue.
        # Non-string labels (e.g. NaN from pandas) are skipped, as in "grade".
        if not isinstance(unified_label, str):
            return None
        if unified_label.endswith("g"):
            return "gland"
        if unified_label.endswith("t"):
            return "epithelial tissue"
        return None
    # "grade": exact matches only; anything else is skipped.
    return {
        "lowgrade_g": "low grade",
        "highgrade_g": "high grade",
        "healthy_g": "non-dysplasia",
    }.get(unified_label)


def assemble_dataset_supervised_learning(full_labels, list_dataset, path_data, data_type):
    """Build a labelled feature table from per-slide CSV files.

    For every file name in ``list_dataset`` the slide name is ``name[14:-4]``
    (a 14-character prefix and a 4-character extension are stripped). The CSV
    is read from ``path_data``. The rows of ``full_labels`` whose ``slide``
    equals that name are scanned in order, and the positions of rows whose
    ``unified_label`` maps to a class are used to select CSV rows with
    ``iloc``. This assumes the CSV rows are in the same order as the slide's
    rows in ``full_labels`` (TODO confirm).

    Parameters
    ----------
    full_labels : pandas.DataFrame
        Must contain ``slide``, ``unified_label`` and ``image_name`` columns.
    list_dataset : iterable of str
        CSV file names relative to ``path_data``.
    path_data : str
        Directory containing the CSV files.
    data_type : {"stroma", "grade"}
        ``"stroma"`` labels by the last character of ``unified_label``
        (``g`` -> "gland", ``t`` -> "epithelial tissue"). ``"grade"`` keeps only
        ``lowgrade_g`` / ``highgrade_g`` / ``healthy_g``.

    Returns
    -------
    tuple
        ``(full_dataset, labels)``. ``full_dataset`` concatenates the selected
        rows with ``dataset_name`` and ``image_name`` columns inserted at
        positions 1 and 2, and any ``Unnamed: 0`` index column dropped.
        ``labels`` lists the class labels in the same row order.

    Raises
    ------
    ValueError
        If ``data_type`` is unknown or no rows were selected at all.
    """
    if data_type not in ("stroma", "grade"):
        raise ValueError(
            f"Unknown data_type {data_type!r}; expected 'stroma' or 'grade'."
        )

    frames = []
    labels = []
    for file_name in list_dataset:
        slide_name = file_name[14:-4]
        temp_df = pd.read_csv(os.path.join(path_data, file_name))
        selected_labels = full_labels[full_labels["slide"] == slide_name]

        list_index = []
        selected_names = []
        pairs = zip(selected_labels["unified_label"], selected_labels["image_name"])
        for i, (unified_label, image_name) in enumerate(pairs):
            label = _select_label(unified_label, data_type)
            if label is None:
                continue
            labels.append(label)
            list_index.append(i)
            selected_names.append(image_name)

        temp_df = temp_df.iloc[list_index]
        if temp_df.empty:
            continue
        n_rows = len(temp_df.index)
        # Report the actual row count (previously printed "range(0, n)").
        print(slide_name + " has " + str(n_rows) + " values")
        temp_df.insert(loc=1, column='dataset_name', value=[slide_name] * n_rows)
        temp_df.insert(loc=2, column='image_name', value=selected_names)
        frames.append(temp_df)

    print("size frames: ", len(frames))
    if not frames:
        raise ValueError(
            f"No labelled samples matched data_type={data_type!r} in the given datasets."
        )
    full_dataset = pd.concat(frames)
    # read_csv turns a saved index into an 'Unnamed: 0' column; drop it when present.
    full_dataset = full_dataset.drop(columns='Unnamed: 0', errors='ignore')
    return full_dataset, labels
131
132
def print_confusion_matrix(y_true, y_pred, class_names, figsize=(6, 5), fontsize=14, normalize=False):
    """Plot a confusion matrix as an annotated seaborn heatmap.

    Parameters
    ----------
    y_true, y_pred : array-like
        True and predicted labels.
    class_names : sequence
        Labels to include, in the row/column order of the matrix.
    figsize : tuple, default (6, 5)
        Size of the created matplotlib figure.
    fontsize : int, default 14
        Font size of the tick labels.
    normalize : bool, default False
        If True, each row is divided by its total and annotated with two
        decimals on a fixed 0-1 colour scale.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.

    Raises
    ------
    ValueError
        If seaborn rejects the annotation format; the original error is chained.

    Notes
    -----
    Calls ``sns.set(font_scale=1.4)``, which changes seaborn's global theme.
    """
    # ``labels`` is keyword-only in scikit-learn >= 1.0; passing it
    # positionally raises TypeError there.
    cm = confusion_matrix(y_true, y_pred, labels=class_names)
    if normalize:
        # Row-normalise once (the previous version did this twice). A class
        # with no true samples produces a NaN row.
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    sns.set(font_scale=1.4)
    df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
    fig = plt.figure(figsize=figsize)
    try:
        if normalize:
            heatmap = sns.heatmap(df_cm, annot=True, fmt='.2f', cmap="Blues", vmin=0, vmax=1)
        else:
            # Colour scale tops out at the largest correctly-classified count.
            # Reading the diagonal of ``cm`` equals the former element-wise
            # count for binary ndarray inputs, but also works for plain lists
            # (where ``list == label`` is a single False) and for >2 classes.
            heatmap = sns.heatmap(
                df_cm, annot=True, fmt='d', cmap="Blues",
                vmax=np.diag(cm).max(), vmin=0,
            )
    except ValueError as err:
        raise ValueError("Confusion matrix values must be integers.") from err
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    fig.tight_layout()
    return fig
167
168
169