Diff of /utils.py [000000] .. [4782c6]

Switch to unified view

a b/utils.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
# @Time    : 2021/8/8 16:21
4
# @Author  : Li Xiao
5
# @File    : utils.py
6
import pandas as pd
7
import numpy as np
8
9
def load_data(adj, fea, lab, threshold=0.005):
    '''
    Load the similarity matrix, the omics features and the sample labels,
    and build the row-normalized adjacency matrix of the filtered graph.

    Each CSV is read with a header row. The first column of every file is
    treated as the sample identifier and renamed to 'Sample'; all three
    tables are then sorted by it so that their rows line up.

    :param adj: the similarity matrix filename
    :param fea: the omics vector features filename
    :param lab: sample labels filename
    :param threshold: the edge filter threshold; similarities below it are
        treated as "no edge"
    :return: tuple ``(adj_hat, fea_df, label_df)`` where ``adj_hat`` is the
        0/1 adjacency matrix with every row divided by its degree, and the
        two DataFrames are sorted by 'Sample'
    :raises SystemExit: (code 1) when the three files do not have the same
        number of rows
    '''
    print('loading data...')
    adj_df = pd.read_csv(adj, header=0, index_col=None)
    fea_df = pd.read_csv(fea, header=0, index_col=None)
    label_df = pd.read_csv(lab, header=0, index_col=None)

    n_samples = adj_df.shape[0]
    if fea_df.shape[0] != n_samples or label_df.shape[0] != n_samples:
        print('Input files must have same samples.')
        # Raise SystemExit directly: the `exit` builtin is injected by the
        # `site` module and is not guaranteed to exist (e.g. `python -S`).
        raise SystemExit(1)

    adj_df = adj_df.rename(columns={adj_df.columns[0]: 'Sample'})
    fea_df = fea_df.rename(columns={fea_df.columns[0]: 'Sample'})
    label_df = label_df.rename(columns={label_df.columns[0]: 'Sample'})

    # align samples of different data
    adj_df = adj_df.sort_values(by='Sample', ascending=True)
    fea_df = fea_df.sort_values(by='Sample', ascending=True)
    label_df = label_df.sort_values(by='Sample', ascending=True)

    print('Calculating the laplace adjacency matrix...')
    adj_m = adj_df.iloc[:, 1:].to_numpy()
    if adj_m.shape[0] == adj_m.shape[1]:
        # Sorting only permuted the rows. The frame was read with a default
        # 0..n-1 index, so after sorting the index holds each row's original
        # position; applying the same permutation to the columns keeps
        # entry (i, j) referring to samples i and j of the sorted order.
        # NOTE(review): assumes column j of the file describes the same
        # sample as row j -- confirm against the matrix producer.
        row_order = adj_df.index.to_numpy()
        adj_m = adj_m[:, row_order]

    # The SNF matrix is a completely connected graph, so weak edges are
    # filtered with a threshold. np.where builds a new array instead of
    # writing into memory that may be shared with (or locked by) the frame.
    filtered = np.where(adj_m < threshold, 0, adj_m)

    # adjacency matrix after filtering
    exist = (filtered != 0) * 1.0

    # degree of each node (number of remaining edges in its row)
    degree = exist.sum(axis=1)

    # Row-normalize: D^-1 * A. Inverting the diagonal element-wise avoids a
    # dense matrix inverse and keeps isolated nodes (degree 0) as all-zero
    # rows instead of failing on a singular degree matrix.
    inv_degree = np.zeros_like(degree)
    np.divide(1.0, degree, out=inv_degree, where=degree > 0)
    adj_hat = exist * inv_degree[:, np.newaxis]

    return adj_hat, fea_df, label_df
54
55
def accuracy(output, labels):
    '''
    Fraction of samples whose highest-scoring class equals the label.

    :param output: tensor of per-class scores; the predicted class of each
        row is the index of its maximum along dimension 1
    :param labels: tensor of true class indices, compared element-wise with
        the predictions (presumably 1-D with one entry per row of
        ``output`` -- confirm against the caller)
    :return: 0-dim double tensor holding ``correct / len(labels)``
    '''
    # max over dim 1 returns (values, indices); only the indices are needed.
    _, best_idx = output.max(1)
    pred = best_idx.type_as(labels)
    n_correct = pred.eq(labels).double().sum()
    return n_correct / len(labels)