# Analysis/collective_pca.py
1
from sklearn.cross_decomposition import CCA
2
import numpy as np
3
import pickle
4
from scipy.io import loadmat
5
from sklearn.decomposition import PCA
6
import matplotlib.pyplot as plt
7
from scipy.ndimage import gaussian_filter1d
8
9
from sklearn.linear_model import LinearRegression, Ridge
10
from sklearn.metrics import r2_score
11
from scipy.spatial import procrustes
12
from sklearn.decomposition import PCA
13
14
# Make the SAC modules importable. The path is relative to the current working
# directory, so the script presumably has to be launched from Analysis/ (TODO confirm).
import sys
sys.path.insert(0, '../SAC/')

# Project-local modules from ../SAC/; kinematics_preprocessing_specs is imported
# but not referenced anywhere in this script.
import kinematics_preprocessing_specs
import config

# Build the shared configuration parser (presumably argparse-style) and parse the
# command line; arguments the parser does not recognise end up in `unknown`.
parser = config.config_parser()
args, unknown = parser.parse_known_args()
22
23
# Load the pickled nusim test data and show which top-level entries it contains.
# NOTE: pickle.load executes arbitrary code from the file; only use trusted files.
with open('../test_data/test_data.pkl', 'rb') as fh:
    test_data = pickle.load(fh)

print(test_data.keys())
29
30
# Load the kinematics of the training and testing conditions. The length of each
# condition's array along its last axis is used as that condition's number of
# timepoints per cycle.
with open('../kinematics_data/kinematics.pkl', 'rb') as file:
    kin_train_test = pickle.load(file)

kin_train = kin_train_test['train']
kin_test = kin_train_test['test']

# Re-key the test conditions so they follow the training ones:
# test condition c -> key len(kin_train) + c.
# A new dict is built on purpose. Popping and re-inserting in place would overwrite
# test entries that have not been moved yet whenever len(kin_train) < len(kin_test).
n_train = len(kin_train)
kin_test = {n_train + cond: kin_test[cond] for cond in range(len(kin_test))}

# Merge into a fresh dict (training conditions first, then testing) so that
# kin_train itself is left untouched.
kin = {**kin_train, **kin_test}

# Timepoints per cycle for every condition, indexed 0..total_conds-1.
conds = [kin[cond].shape[-1] for cond in range(len(kin))]
total_conds = len(conds)
46
47
# Cycle to analyse for each condition (training conditions followed by testing):
# 0 selects the 1st cycle, 1 the 2nd, and so on.
# There must be one entry per condition (num_train_conditions + num_test_conditions).
cycles = [2, 2, 2, 2, 2, 2]

# Fail fast with a clear message rather than an IndexError when cycles is indexed
# per condition further down.
if len(cycles) < total_conds:
    raise ValueError(
        f"cycles has {len(cycles)} entries but there are {total_conds} conditions; "
        "provide one cycle index per condition"
    )

# Number of fixed steps at the start of each condition, read from the parsed CLI/config.
# TODO: get this value automatically from ../SAC/kinematics_preprocessing_specs.py.
n_fixedsteps = args.n_fixedsteps
54
55
# Load the network activities. For every condition keep only the selected cycle:
# rows [start, start + cycle_len), where start skips the n_fixedsteps initial steps
# plus cycles[cond] whole cycles.
A_agent = []

for cond_idx, activity in test_data['rnn_activity'].items():
    # The rnn_activity keys are used directly as indices into cycles/conds
    # (assumes they are 0..N-1 in the same order as kin; TODO confirm).
    cycle_len = conds[cond_idx]
    start = n_fixedsteps + cycles[cond_idx] * cycle_len
    segment = activity[start:start + cycle_len]
    # NOTE: slicing past the end of `activity` silently yields fewer than cycle_len rows.
    print(segment.shape)
    A_agent.append(segment[:, :])
63
64
# Collective PCA: fit a single 3-component PCA on the activity of all conditions
# stacked together, so every condition is projected into the same PC space.
nusim_pca = PCA(n_components=3)

A_agent_c = A_agent

if not A_agent_c:
    raise ValueError("no network activity found in test_data['rnn_activity']")

# Stack all conditions along axis 0 in one call. This is linear in the total size,
# whereas concatenating pairwise inside a loop re-copies the growing array for every
# condition. Rows keep the order of A_agent_c, which the plotting code relies on to
# split the projection back into per-condition trajectories.
nusim_activity_pca = np.concatenate(A_agent_c, axis=0)

nusim_activity_pca = nusim_pca.fit_transform(nusim_activity_pca)
77
78
# Plot each condition's trajectory in the space of the first three PCs.
# Sample at least 8 colours (the original palette) and more when there are more
# conditions, so colors[i_cond] can never run out.
colors = plt.cm.ocean(np.linspace(0, 1, max(8, len(A_agent_c))))
ax = plt.figure(dpi=100).add_subplot(projection='3d')

# The rows of nusim_activity_pca are the conditions stacked in A_agent_c order;
# walk through them with a running offset.
prev_cond = 0
for i_cond, cond_act in enumerate(A_agent_c):
    n_steps = cond_act.shape[0]
    traj = nusim_activity_pca[prev_cond:prev_cond + n_steps]
    ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], color=colors[i_cond])
    prev_cond += n_steps

# Hide grid lines. Only the Axes method is used: the `b=` keyword of grid() was
# deprecated in Matplotlib 3.5 in favour of `visible`.
ax.grid(False)

# Hide the axis ticks and the axes themselves.
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
plt.axis('off')

plt.show()