--- a +++ b/Analysis/find_fp.py @@ -0,0 +1,116 @@ +''' +examples/torch/run_FlipFlop.py +Written for Python 3.8.17 and Pytorch 2.0.1 +@ Matt Golub, June 2023 +Please direct correspondence to mgolub@cs.washington.edu +''' + +import pdb +import sys +import numpy as np +import torch +import pickle + + +PATH_TO_FIXED_POINT_FINDER = './fixed-point-finder/' +PATH_TO_HELPER = './fixed-point-finder/examples/helper/' +PATH_TO_TORCH = './fixed-point-finder/examples/torch/' +PATH_TO_SAC = '../SAC/' +sys.path.insert(0, PATH_TO_FIXED_POINT_FINDER) +sys.path.insert(0, PATH_TO_HELPER) +sys.path.insert(0, PATH_TO_TORCH) +sys.path.insert(0, PATH_TO_SAC) + +from FlipFlop import FlipFlop +from FixedPointFinderTorch import FixedPointFinderTorch as FixedPointFinder +from FlipFlopData import FlipFlopData +from plot_utils import plot_fps + + +def find_fixed_points(model, rnn_trajectories, rnn_input): + ''' Find, analyze, and visualize the fixed points of the trained RNN. + + Args: + model: + + Trained RNN model, as returned by uSim training. + + valid_predictions: dict. + + Model trajectories for training and testing conditions. + + Returns: + None. + ''' + + NOISE_SCALE = 0.5 # Standard deviation of noise added to initial states + # N_INITS = 1024*10 + N_INITS = rnn_trajectories.shape[0] * rnn_trajectories.shape[1] # The number of initial states to provide + + n_hidden_units = rnn_trajectories.shape[2] + + '''Fixed point finder hyperparameters. See FixedPointFinder.py for detailed + descriptions of available hyperparameters.''' + fpf_hps = { + 'max_iters': 10000, + 'lr_init': 1., + 'outlier_distance_scale': 10.0, + 'verbose': True, + 'super_verbose': True} + + # Setup the fixed point finder + fpf = FixedPointFinder(model, **fpf_hps) + + '''Draw random, noise corrupted samples of those state trajectories + to use as initial states for the fixed point optimizations.''' + initial_states = fpf.sample_states(rnn_trajectories, + n_inits=N_INITS, + noise_scale=NOISE_SCALE) + + # Study the system in the absence of input pulses (e.g., all inputs are 0) + inputs = np.zeros([1, n_hidden_units]) + # inputs = rnn_input.reshape(-1, n_hidden_units) + + # Run the fixed point finder + unique_fps, all_fps = fpf.find_fixed_points(initial_states, inputs) + + # Visualize identified fixed points with overlaid RNN state trajectories + # All visualized in the 3D PCA space fit the the example RNN states. + fig = plot_fps(unique_fps, rnn_trajectories, + plot_batch_idx=list(range(6)), + plot_start_time=0) + +def main(): + + # Step 1: Load the uSim RNN model, RNN trajectories and RNN input + PATH_TO_MODEL = '../checkpoint/actor_rnn_best_fpf.pth' + actor_rnn = torch.load(PATH_TO_MODEL) + model = actor_rnn + + #Load the test data + with open('../test_data/test_data.pkl', 'rb') as file: + test_data = pickle.load(file) + + #Get the uSim RNN hidden trajectories and inputs + rnn_trajectories = [] + for cond in range(len(test_data['rnn_activity'])): + rnn_trajectories.append(test_data['rnn_activity'][cond]) + + rnn_trajectories = np.array(rnn_trajectories) #[n_conds, n_timepoints, n_hidden_units] + + rnn_input = [] + for cond in range(len(test_data['rnn_input_fp'])): + rnn_input.append(test_data['rnn_input_fp'][cond]) + + rnn_input = np.array(rnn_input) #[n_conds, n_timepoints, n_hidden_units] + + # STEP 2: Find, analyze, and visualize the fixed points of the trained RNN + find_fixed_points(model, rnn_trajectories, rnn_input) + + print('Entering debug mode to allow interaction with objects and figures.') + print('You should see a figure with:') + + pdb.set_trace() + +if __name__ == '__main__': + main() \ No newline at end of file