[9f010e]: / Analysis / .ipynb_checkpoints / find_fp-checkpoint.py

Download this file

116 lines (86 with data), 3.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
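'''
Find, analyze, and visualize the fixed points of a trained uSim actor RNN.
Based on examples/torch/run_FlipFlop.py from fixed-point-finder
(@ Matt Golub, June 2023; please direct correspondence about that example
to mgolub@cs.washington.edu).
Written for Python 3.8.17 and Pytorch 2.0.1
'''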
import pdb
import sys
import numpy as np
import torch
import pickle
# Locations of the vendored fixed-point-finder package, its example helper
# and torch example modules, and the SAC project. These are relative paths,
# so Python resolves them against the current working directory (not this
# file's directory) when the imports below run.
PATH_TO_FIXED_POINT_FINDER = './fixed-point-finder/'
PATH_TO_HELPER = './fixed-point-finder/examples/helper/'
PATH_TO_TORCH = './fixed-point-finder/examples/torch/'
PATH_TO_SAC = '../SAC/'
# Prepend all four in one step. The resulting front of sys.path is
# [SAC, torch examples, helper, fixed-point-finder], so SAC has the highest
# import priority.
sys.path[:0] = [PATH_TO_SAC, PATH_TO_TORCH, PATH_TO_HELPER,
                PATH_TO_FIXED_POINT_FINDER]
from FlipFlop import FlipFlop
from FixedPointFinderTorch import FixedPointFinderTorch as FixedPointFinder
from FlipFlopData import FlipFlopData
from plot_utils import plot_fps
def find_fixed_points(model, rnn_trajectories, rnn_input,
                      noise_scale=0.5, n_plot_conditions=6):
    '''Find, analyze, and visualize the fixed points of the trained RNN.

    Initial states for the fixed point optimizations are sampled, with added
    noise, from the RNN's own hidden-state trajectories. The search is then
    run with all inputs held at zero.

    Args:
        model:
            Trained RNN model, as returned by uSim training. Passed directly
            to FixedPointFinderTorch.
        rnn_trajectories: np.ndarray of shape
            [n_conds, n_timepoints, n_hidden_units].
            Hidden-state trajectories of the RNN, one per test condition.
            Used to seed the optimizations and drawn in the resulting plot.
        rnn_input: np.ndarray.
            RNN inputs for the same conditions. Currently unused, because
            the search is performed with zero input (see below).
        noise_scale: float, optional (default 0.5).
            Standard deviation of the noise added to the sampled initial
            states.
        n_plot_conditions: int, optional (default 6).
            Maximum number of conditions whose trajectories are plotted.
            Clamped to the number of available conditions.

    Returns:
        (unique_fps, all_fps, fig):
            unique_fps, all_fps: the two values returned by
                FixedPointFinder.find_fixed_points (presumably the
                de-duplicated and the full set of fixed points; confirm in
                FixedPointFinder.py).
            fig: the figure returned by plot_fps.
    '''
    # Provide one initial state per (condition, timepoint) sample.
    n_conds, n_timepoints, n_hidden_units = rnn_trajectories.shape[:3]
    n_inits = n_conds * n_timepoints

    # Fixed point finder hyperparameters. See FixedPointFinder.py for
    # detailed descriptions of the available hyperparameters.
    fpf_hps = {
        'max_iters': 10000,
        'lr_init': 1.,
        'outlier_distance_scale': 10.0,
        'verbose': True,
        'super_verbose': True}

    # Set up the fixed point finder.
    fpf = FixedPointFinder(model, **fpf_hps)

    # Draw random, noise-corrupted samples of the state trajectories to use
    # as initial states for the fixed point optimizations.
    initial_states = fpf.sample_states(rnn_trajectories,
                                       n_inits=n_inits,
                                       noise_scale=noise_scale)

    # Study the system in the absence of input (all inputs are 0).
    # NOTE(review): the zero input is sized by n_hidden_units. That presumes
    # the model's input dimension equals its hidden size (main describes
    # rnn_input as [n_conds, n_timepoints, n_hidden_units]). Confirm this
    # against the model.
    inputs = np.zeros([1, n_hidden_units])

    # Run the fixed point finder.
    unique_fps, all_fps = fpf.find_fixed_points(initial_states, inputs)

    # Visualize the identified fixed points with the RNN state trajectories
    # overlaid, in the 3D PCA space fit to the RNN states. plot_batch_idx
    # selects conditions along axis 0 of rnn_trajectories, so never request
    # more conditions than exist.
    n_plot = min(n_plot_conditions, n_conds)
    fig = plot_fps(unique_fps, rnn_trajectories,
                   plot_batch_idx=list(range(n_plot)),
                   plot_start_time=0)

    # Hand the results back so the caller can inspect them.
    return unique_fps, all_fps, fig
def main():
    '''Load the trained uSim actor RNN and its test data, then find its fixed points.

    All paths are relative to the current working directory. After the
    analysis, this function drops into pdb so the model, data, and
    fixed-point results can be inspected interactively.
    '''
    # Step 1: Load the uSim RNN model, the RNN trajectories and the RNN input.
    PATH_TO_MODEL = '../checkpoint/actor_rnn_best_fpf.pth'
    # The checkpoint is passed straight to the fixed point finder as a model,
    # so it presumably holds a full pickled module rather than a state_dict.
    # Request full unpickling explicitly: PyTorch >= 2.6 defaults to
    # weights_only=True, which would reject it. Only load trusted checkpoints.
    model = torch.load(PATH_TO_MODEL, weights_only=False)

    # Load the test data. pickle can execute arbitrary code, so only load
    # trusted files.
    with open('../test_data/test_data.pkl', 'rb') as f:
        test_data = pickle.load(f)

    # Stack the per-condition hidden trajectories and inputs into arrays.
    activity = test_data['rnn_activity']
    rnn_trajectories = np.array(
        [activity[cond] for cond in range(len(activity))]
    )  # [n_conds, n_timepoints, n_hidden_units]
    inputs_fp = test_data['rnn_input_fp']
    rnn_input = np.array(
        [inputs_fp[cond] for cond in range(len(inputs_fp))]
    )  # [n_conds, n_timepoints, n_hidden_units]

    # Step 2: Find, analyze, and visualize the fixed points of the trained RNN.
    # Keep the return value in a local variable so it is reachable from the
    # pdb prompt below.
    fp_results = find_fixed_points(model, rnn_trajectories, rnn_input)

    print('Entering debug mode to allow interaction with objects and figures.')
    print('You should see a figure with the identified fixed points overlaid '
          'on the RNN state trajectories.')
    pdb.set_trace()


if __name__ == '__main__':
    main()