Diff of /Analysis/find_fp.py [000000] .. [9f010e]

Switch to unified view

a b/Analysis/find_fp.py
1
'''
2
examples/torch/run_FlipFlop.py
3
Written for Python 3.8.17 and Pytorch 2.0.1
4
@ Matt Golub, June 2023
5
Please direct correspondence to mgolub@cs.washington.edu
6
'''
7
8
import pdb
9
import sys
10
import numpy as np
11
import torch
12
import pickle
13
14
15
PATH_TO_FIXED_POINT_FINDER = './fixed-point-finder/'
16
PATH_TO_HELPER = './fixed-point-finder/examples/helper/'
17
PATH_TO_TORCH = './fixed-point-finder/examples/torch/'
18
PATH_TO_SAC = '../SAC/'
19
sys.path.insert(0, PATH_TO_FIXED_POINT_FINDER)
20
sys.path.insert(0, PATH_TO_HELPER)
21
sys.path.insert(0, PATH_TO_TORCH)
22
sys.path.insert(0, PATH_TO_SAC)
23
24
from FlipFlop import FlipFlop
25
from FixedPointFinderTorch import FixedPointFinderTorch as FixedPointFinder
26
from FlipFlopData import FlipFlopData
27
from plot_utils import plot_fps
28
29
30
def find_fixed_points(model, rnn_trajectories, rnn_input):
31
    ''' Find, analyze, and visualize the fixed points of the trained RNN.
32
33
    Args:
34
        model: 
35
36
            Trained RNN model, as returned by uSim training.
37
38
        valid_predictions: dict.
39
40
            Model trajectories for training and testing conditions.
41
42
    Returns:
43
        None.
44
    '''
45
46
    NOISE_SCALE = 0.5 # Standard deviation of noise added to initial states
47
    # N_INITS = 1024*10
48
    N_INITS = rnn_trajectories.shape[0] * rnn_trajectories.shape[1] # The number of initial states to provide
49
50
    n_hidden_units = rnn_trajectories.shape[2]
51
52
    '''Fixed point finder hyperparameters. See FixedPointFinder.py for detailed
53
    descriptions of available hyperparameters.'''
54
    fpf_hps = {
55
        'max_iters': 10000,
56
        'lr_init': 1.,
57
        'outlier_distance_scale': 10.0,
58
        'verbose': True, 
59
        'super_verbose': True}
60
61
    # Setup the fixed point finder
62
    fpf = FixedPointFinder(model, **fpf_hps)
63
64
    '''Draw random, noise corrupted samples of those state trajectories
65
    to use as initial states for the fixed point optimizations.'''
66
    initial_states = fpf.sample_states(rnn_trajectories,
67
        n_inits=N_INITS,
68
        noise_scale=NOISE_SCALE)
69
70
    # Study the system in the absence of input pulses (e.g., all inputs are 0)
71
    inputs = np.zeros([1, n_hidden_units])
72
    # inputs = rnn_input.reshape(-1, n_hidden_units)
73
74
    # Run the fixed point finder
75
    unique_fps, all_fps = fpf.find_fixed_points(initial_states, inputs)
76
77
    # Visualize identified fixed points with overlaid RNN state trajectories
78
    # All visualized in the 3D PCA space fit the the example RNN states.
79
    fig = plot_fps(unique_fps, rnn_trajectories,
80
        plot_batch_idx=list(range(6)),
81
        plot_start_time=0)
82
83
def main():
84
85
    # Step 1: Load the uSim RNN model, RNN trajectories and RNN input
86
    PATH_TO_MODEL = '../checkpoint/actor_rnn_best_fpf.pth' 
87
    actor_rnn = torch.load(PATH_TO_MODEL)
88
    model = actor_rnn
89
90
    #Load the test data
91
    with open('../test_data/test_data.pkl', 'rb') as file:
92
        test_data = pickle.load(file)
93
94
    #Get the uSim RNN hidden trajectories and inputs
95
    rnn_trajectories = []
96
    for cond in range(len(test_data['rnn_activity'])):
97
        rnn_trajectories.append(test_data['rnn_activity'][cond])
98
99
    rnn_trajectories = np.array(rnn_trajectories)  #[n_conds, n_timepoints, n_hidden_units]
100
101
    rnn_input = []
102
    for cond in range(len(test_data['rnn_input_fp'])):
103
        rnn_input.append(test_data['rnn_input_fp'][cond])
104
105
    rnn_input = np.array(rnn_input) #[n_conds, n_timepoints, n_hidden_units]
106
107
    # STEP 2: Find, analyze, and visualize the fixed points of the trained RNN
108
    find_fixed_points(model, rnn_trajectories, rnn_input)
109
110
    print('Entering debug mode to allow interaction with objects and figures.')
111
    print('You should see a figure with:')
112
113
    pdb.set_trace()
114
115
if __name__ == '__main__':
116
    main()