Switch to unified view

a b/Analysis/jPCA_ForDistribution/phaseSpace.m
1
% For making publication quality rosette plots
2
% useage:
3
%       phaseSpace(Projection, Summary)
4
%       phaseSpace(Projection, Summary, params)
5
%
6
% You can limit the conditions you plot either by
7
%   1) including only some entries of Projection (e.g., Projection(1:2))
8
%   or 2) by using params.conds2plot to restrict (e.g., params.conds2plot = 1:2);
9
%   In the former case, scaling etc is based just on the passed points.
10
%   In the latter scaling is based on all the points.
11
%
12
% outputs are [colorStruct, hf, haxP, vaxP] = phaseSpace(Projection, Summary)
13
%   
14
%   'colorStruct' is a structure (one per dataset) of cells of linecolors (one per condition)
15
%   that you might wish to pass to another function (e.g., one that plots the rosette PSTH's 
16
%   or the hand trajectories).
17
%
18
%    'hf' is the fig#, 'haxP' and 'vaxP' are the axis parameters.
19
%
20
%   *** ALL of the outputs, except the last, pertain only the LAST graph plotted. ***
21
%
22
% rosetteData comes from multiRosetteScript
23
%
24
% params can have the following fields:
25
%       .times    These override the default times (those corresponding to scores; e.g., the orginal
26
%               times that were used to build the space.  If empty, the defaults are used.
27
%                 Note that zero is movement onset.  Thus .times will probably start strongly
28
%               negatively.  
29
%                 Thus, you might pass it -1550:10:150 if you wanted to start way back at the
30
%               beginning.
31
%                 A nice feature is that only those times that match times in 'scoresExtraTime' are
32
%               used.  Thus, if you pass -100000:1000000 things will still work fine.
33
%               Scalings are based on all times in Projection.projAllTimes.  This is nice for movies
34
%               as the scaling won't change as a function of the times you plot.
35
%
36
%       .planes2plot   list of the jPC planes you want plotted.  Default is [1].  [1,2] would also be reasonable.
37
%
38
%       .arrowSize        The default is 5
39
%       .arrowGain        FOR MOVIES: sets velocity dependence of arrow size (0 to not grow when faster).
40
%       .plotPlanEllipse  Controls whether the ellipse is plotted.  The default is 'true'
41
%       .useAxes          whether axes should be plotted.  Default is 'true'
42
%       .useLabel         whether to label with dataset. Default is 'true'
43
%       .planMarkerSize   size of the plan dot.  Default is 6.
44
%       .lineWidth        width of the trajectories.  Default is 0.85.
45
%       .arrowMinVel      minimum velocity for plotting an arrow.
46
%       .rankType         default is 'eig', but you can override with 'varCapt'.  The first plane will
47
%                         then be the jPC plane that captured the most variance (often associated with the 
48
%                         largest eigenvalue but not always
49
%       .conds2plot       which conditions to plot (scalings will still be based on all the conds in 'Projection')
50
%       .substRawPCs      use PC projections rather than jPC projections
51
%       .crossCondMean    if present and == 1, plot the cross condition mean in cyan.
52
%       .reusePlot        if present and == 1, do cla then reuse the plot
53
%       .dataRanges       normally this is set automatically, but you can decide yourself what the
54
%                         range should be.  You should supply one entry per plane to be plotted.  You can also supply
55
%                         just the first, and then the defaults will be used after that.
56
%
57
function [colorStruct, haxP, vaxP] = phaseSpace(Projection, Summary, params)
58
59
60
%% some basic parameters
61
axLimScale = 1.35; 
62
axisSeparation = 0.20;  % separated by 20% of the maximum excursion (may need more if plotting future times, which aren't used to compute farthestLeft or farthestDown)
63
64
numPlanes = length(Summary.varCaptEachPlane);  % total number of planes provided (may only plot a subset)
65
66
%% set defaults and override if 'params' is included as an argument
67
68
% allows for the use of times other than the original ones that correspond to 'scores' (those that
69
% were used to create the projection and do the analysis)
70
overrideTimes = [];
71
if exist('params', 'var') && isfield(params,'times')
72
    overrideTimes = params.times;
73
end
74
75
arrowSize = 5;
76
if exist('params', 'var') && isfield(params,'arrowSize')
77
    arrowSize = params.arrowSize;
78
end
79
80
arrowGain = 0;
81
if exist('params', 'var') && isfield(params,'arrowGain')
82
    arrowGain = params.arrowGain;
83
end
84
85
% Default is we plot the ellipse if we have 6 or more conditions
86
% NOTE: we still plot it even if asked to only plot one cond, so long as we HAVE more than 6 to
87
% build the ellipse off.  It matters whether length(Projection) >= 6, not whether conds2plot >= 6
88
if length(Projection) >= 6
89
    plotPlanEllipse = true;
90
else
91
    plotPlanEllipse = false; 
92
    axLimScale = 1.3*axLimScale;
93
end
94
if exist('params', 'var') && isfield(params,'plotPlanEllipse')
95
    plotPlanEllipse = params.plotPlanEllipse;
96
end
97
98
useAxes = true;
99
if exist('params', 'var') && isfield(params,'useAxes')
100
    useAxes = params.useAxes;
101
end
102
103
useLabel = true;
104
if exist('params', 'var') && isfield(params,'useLabel')
105
    useLabel = params.useLabel;
106
end
107
108
planMarkerSize = 6;
109
if exist('params', 'var') && isfield(params,'planMarkerSize')
110
    planMarkerSize = params.planMarkerSize;
111
end
112
113
lineWidth = 0.85;
114
if exist('params', 'var') && isfield(params,'lineWidth')
115
    lineWidth = params.lineWidth;
116
end
117
118
arrowMinVel = [];
119
if exist('params', 'var') && isfield(params,'arrowMinVel')
120
    arrowMinVel = params.arrowMinVel;
121
end
122
123
planes2plot = [1];  % this is a list of which planes to plot 
124
if exist('params', 'var') && isfield(params,'planes2plot')
125
    planes2plot = params.planes2plot;
126
end
127
128
rankType = 'eig';
129
if exist('params', 'var') && isfield(params,'rankType')
130
    rankType = params.rankType;
131
end
132
133
reusePlot = 0;
134
if exist('params', 'var') && isfield(params,'reusePlot')
135
    reusePlot = params.reusePlot;
136
end
137
if length(planes2plot) > 1, reusePlot = 0; end  % cant reuse if we are plotting more than one thing.
138
139
numConds = length(Projection);
140
conds2plot = 1:numConds;
141
if exist('params', 'var') && isfield(params,'conds2plot')
142
    if ~strcmp(params.conds2plot,'all')
143
        conds2plot = params.conds2plot;
144
    end
145
end
146
147
% if asked, plot the cross-condition mean on the same plot.
148
crossCondMean = false;
149
if exist('params', 'var') && isfield(params,'crossCondMean')
150
    crossCondMean = params.crossCondMean;
151
end
152
153
154
% If asked, substitue the raw PC projections.
155
substRawPCs = 0;
156
if exist('params', 'var') && isfield(params,'substRawPCs') && params.substRawPCs
157
    substRawPCs = 1;
158
    % just overwrite
159
    for c = 1:numConds
160
        Projection(c).proj = Projection(c).tradPCAproj;
161
        Projection(c).projAllTimes = Projection(c).tradPCAprojAllTimes;
162
    end
163
    Summary.varCaptEachPlane = sum(reshape(Summary.varCaptEachPC,2,numPlanes));
164
end
165
166
if strcmp(rankType, 'varCapt')  && substRawPCs == 0  % we WONT reorder if they were PCs
167
    [~, sortIndices] = sort(Summary.varCaptEachPlane,'descend');
168
    planes2plot_Orig = planes2plot;  % keep this so we can label the plane appropriately
169
    planes2plot = sortIndices(planes2plot);  % get the asked for planes, but by var accounted for rather than eigenvalue
170
end
171
172
% the range of the data will set the size of the plot unless you manually override
173
dataRanges = max(abs(vertcat(Projection.proj)));
174
dataRanges = max(reshape(dataRanges,2,numPlanes));  % one range per plane
175
176
if exist('params', 'var') && isfield(params,'dataRanges')
177
    for i = 1:length(params.dataRanges)
178
        dataRanges(i) = params.dataRanges(i);  % only override those values that are specified
179
    end
180
end
181
    
182
183
184
arrowEdgeColor = 'k';
185
186
for pindex = 1:length(planes2plot)
187
    
188
    % get some useful indices
189
    plane = planes2plot(pindex);  % which plane to plot
190
    d2 = 2*plane;  % indices into the dimensions
191
    d1 = d2-1;
192
193
    % set the limits of the figure
194
    axLim = axLimScale * dataRanges(plane) * [-1 1 -1 1];
195
    axisLength = 0.5;
196
197
    
198
    for c = 1:numConds
199
        % Always taken from the 1st element (NOT the first that will be plotted)
200
        % This way the ellipse doesn't depend on which times you choose to plot 
201
        planData(c,:) = Projection(c).proj(1,[d1,d2]);  
202
    end
203
    % need this for ellipse plotting
204
    if numConds > 1   % used to be only if plotPlanEllipse==1.  Now we always use the same scaling regardless of whether we use the plan ellipse.
205
        ellipseRadii = 2*var(planData).^0.5;  % we may plot an ellipse for the plan activity
206
207
        % these will be altered further below based on how far the data extends left and down
208
        farthestLeft = -ellipseRadii(1);  % used figure out how far the axes need to be offset
209
        farthestDown = -ellipseRadii(2);  % used figure out how far the axes need to be offset
210
    else 
211
        temp = vertcat(Projection.proj);
212
        temp = temp(:,[d1,d2]);
213
        farthestLeft = -1.2*max(temp(:));  % used figure out how far the axes need to be offset
214
        farthestDown = -1.2*max(temp(:));  % used figure out how far the axes need to be offset
215
    end
216
217
    %% deal with the color scheme
218
219
    % ** colors graded based on PLAN STATE
220
    % These do NOT depend on which times you choose to plot (only on which time is first in Projection.proj).
221
    htmp = redgreencmap(numConds, 'interpolation', 'linear');
222
    [~,newColorIndices] = sort(planData(:,1));
223
224
    htmp(newColorIndices,:) = htmp;
225
226
    for c = 1:numConds  % cycle through conditions, and assign that condition's color
227
        lineColor{c} = htmp(c,:);
228
        arrowFaceColor{c} = htmp(c,:);
229
        planMarkerColor{c} = htmp(c,:);
230
    end
231
232
    % override colors if asked
233
    if exist('params', 'var') && isfield(params,'colors')
234
        lineColor = params.colors;
235
        arrowFaceColor = params.colors;
236
        planMarkerColor = params.colors;
237
        disp('hi');
238
    end
239
    
240
    colorStruct(pindex).colors = lineColor;
241
242
    %% Plot the rosette itself
243
    if reusePlot == 0, blankFigure(axLim); else cla; end
244
245
    % first deal with the ellipse for the plan variance (we want this under the rest of the data)
246
    if plotPlanEllipse, circle([0 0], ellipseRadii, 0.6*[1 1 1], 1); end
247
248
    % cycle through conditions
249
    for c = 1:numConds
250
251
        if isempty(overrideTimes)  % if we are going with the original times (those that were used to create the projection and do the analysis)        
252
            P1 = Projection(c).proj(:,d1);
253
            P2 = Projection(c).proj(:,d2);       
254
        else
255
            useTimes = ismember(Projection(c).allTimes, overrideTimes);
256
            P1 = Projection(c).projAllTimes(useTimes,d1);
257
            P2 = Projection(c).projAllTimes(useTimes,d2);
258
        end
259
260
        if ismember(c,conds2plot)
261
            plot(P1, P2, 'color', lineColor{c}, 'lineWidth', lineWidth);
262
263
            if planMarkerSize>0
264
                plot(P1(1), P2(1), 'ko', 'markerSize', planMarkerSize, 'markerFaceColor', planMarkerColor{c});
265
            end
266
267
            % for arrow, figure out last two points, and (if asked) supress the arrow if velocity is
268
            % below a threshold.
269
            penultimatePoint = [P1(end-1), P2(end-1)];
270
            lastPoint = [P1(end), P2(end)];
271
            vel = norm(lastPoint - penultimatePoint);
272
            if isempty(arrowMinVel) || vel > arrowMinVel
273
                aSize = arrowSize + arrowGain * vel;  % if asked (e.g. for movies) arrow size may grow with vel
274
                arrowMMC(penultimatePoint, lastPoint, [], aSize, axLim, arrowFaceColor{c}, arrowEdgeColor);
275
            else
276
                plot(lastPoint(1), lastPoint(2), 'ko', 'markerSi', arrowSize, 'markerFac', arrowFaceColor{c}, 'markerEdge', arrowEdgeColor);
277
            end
278
        end
279
280
        % axis locations will be based on the original set of times used to make the scores
281
        % and not on the actual times used.  Here we get the leftmost and bottommost point
282
        if isfield(Projection, 'projAllTimes')
283
            farthestLeft = min(farthestLeft, min(Projection(c).projAllTimes(:,d1)));
284
            farthestDown = min(farthestDown, min(Projection(c).projAllTimes(:,d2)));
285
        else
286
            farthestLeft = min(farthestLeft, min(Projection(c).proj(:,d1)));
287
            farthestDown = min(farthestDown, min(Projection(c).proj(:,d2)));
288
        end
289
    end
290
291
    plot(0,0,'b+', 'markerSi', 7.5);  % plot a central cross
292
    
293
    
294
    %% if asked we will also plot the cross condition mean
295
    if crossCondMean && length(Summary.crossCondMean) > 1
296
        meanColor = [0 1 1];
297
        
298
        if isempty(overrideTimes)  % if we are going with the original times (those that were used to create the projection and do the analysis)        
299
            P1 = Summary.crossCondMean(:,d1);
300
            P2 = Summary.crossCondMean(:,d2);       
301
        else
302
            useTimes = ismember(Projection(c).allTimes, overrideTimes);
303
            P1 = Summary.crossCondMeanAllTimes(useTimes,d1);
304
            P2 = Summary.crossCondMeanAllTimes(useTimes,d2);
305
        end
306
        
307
        plot(P1, P2, 'color', meanColor, 'lineWidth', 1.2*lineWidth);  % make slightly thicker than for rest of data.
308
        if planMarkerSize>0
309
            plot(P1(1), P2(1), 'ko', 'markerSize', planMarkerSize, 'markerFaceColor', meanColor);
310
        end
311
        
312
        % for arrow, figure out last two points, and (if asked) supress the arrow if velocity is
313
        % below a threshold.
314
        penultimatePoint = [P1(end-1), P2(end-1)];
315
        lastPoint = [P1(end), P2(end)];
316
        vel = norm(lastPoint - penultimatePoint);
317
318
        aSize = arrowSize + arrowGain * vel;  % if asked (e.g. for movies) arrow size may grow with vel
319
        arrowMMC(penultimatePoint, lastPoint, [], aSize, axLim, meanColor, arrowEdgeColor);
320
  
321
    end
322
323
    %% make axes
324
    if useAxes
325
        clear axisParams;
326
327
        extraSeparation = axisSeparation*(min(farthestDown,farthestLeft));
328
329
        % general axis parameters
330
        axisParams.tickLocations = [-axisLength, 0, axisLength];
331
        axisParams.longTicks = 0;
332
        axisParams.fontSize = 10.5;
333
334
        % horizontal axis
335
        axisParams.axisOffset = farthestDown + extraSeparation;
336
        axisParams.axisLabel = 'projection onto jPC_1 (a.u.)';
337
        axisParams.axisOrientation = 'h';
338
        haxP = AxisMMC(-axisLength, axisLength, axisParams);
339
340
        % vertical axis
341
        axisParams.axisOffset = farthestLeft + extraSeparation;
342
        axisParams.axisLabel = 'projection onto jPC_2 (a.u.)';
343
        axisParams.axisOrientation = 'v';
344
        axisParams.axisLabelOffset = 1.9*haxP.axisLabelOffset;
345
        vaxP = AxisMMC(-axisLength, axisLength, axisParams);
346
    end
347
348
    
349
    % plot a label at the top
350
    if substRawPCs == 0, planeType = 'jPCA'; else planeType = 'PCA'; end
351
    if useLabel
352
        if substRawPCs == 1
353
            titleText = sprintf('raw PCA plane %d', plane);
354
        elseif strcmp(rankType, 'varCapt')
355
            letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ';
356
            titleText = sprintf('jPCA plane %s (var capt ranked)', letters(planes2plot_Orig(pindex)));
357
        else
358
            titleText = sprintf('jPCA plane %d (eigval ranked)', plane);
359
        end
360
        titleText2 = sprintf('%d%% of var captured', round(100*Summary.varCaptEachPlane(plane)));
361
        text(0,0.99*axLim(4),titleText, 'horizo', 'center');
362
        text(0,0.88*axLim(4),titleText2, 'horizo', 'center', 'fontSize', 8.5);
363
    end
364
365
366
367
end  % done looping through planes
368
369
end  % end of the main function
370
371
372
373