Diff of /STM_HFS.m [000000] .. [d8e26d]

Switch to unified view

a b/STM_HFS.m
1
function [Sol_MT] = STM_HFS(Xs, ys, Lambda, opts)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Implementation of the sequential HFS rule for STM
%
% For every value in Lambda (processed from largest to smallest) the
% function screens the features, builds the reduced data matrices and the
% reduced tree, and calls tree_LeastR_MT on the reduced problem.
%
%% input:
%         Xs:
%            Xs{i} stores the data matrix of the i-th task, each column
%            corresponds to a feature, each row corresponds to a data
%            instance. All tasks are indexed with p = size(Xs{1},2)
%            features.
%
%         ys:
%            ys{i} stores the response vector of the i-th task. The
%            vectors are concatenated vertically, so they are expected to
%            be column vectors.
%
%         Lambda:
%            the parameter values of lambda. If opts.rFlag == 1 they are
%            multiplied by lambda_max before use.
%
%         opts:
%            settings for the solver. Fields read here: ind, tFlag, rFlag
%            and (when tFlag == 2) funVal. Fields written here before the
%            solver is called: init, rFlag, x0, ind_MT and (when
%            tFlag == 2) tol.
%
%% output:
%         Sol_MT:
%              p x T_num x length(Lambda) array; Sol_MT(:,:,i) stores the
%              solution for the i-th value in Lambda (original order, not
%              the sorted order used internally).
%
%% For any problem, please contact Weizhong Zhang (zhangweizhongzju@gmail.com)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% -------------------------- pass parameters ---------------------------- %
p = size(Xs{1}, 2);
T_num = length(Xs); % number of tasks
npar = length(Lambda); % number of parameter values of lambda
ind = opts.ind;
ind_MT = TreeTransform(ind, T_num);
opts.init=1;        % starting from a zero point (original author's note)
if opts.tFlag==2
    funVal = opts.funVal;
end

% --------------------- recover tree structure for Multi Task ----------- %
% Columns of ind_MT whose end index equals p*T_num close one tree level.
eind_MT = find(ind_MT(2,:) == p*T_num);
if ind_MT(1,1) == -1 % find the depth of the tree
    d_MT = length(eind_MT);
    nnl_MT = [p*T_num,eind_MT(1)-1,diff(eind_MT)]; % number of nodes per layer
    nnind_MT = [0,1,eind_MT]; % ind_MT(1,nnind_MT(i)+1:nnind_MT(i+1)) stores the nodes of the i-th layer
else
    d_MT = length(eind_MT)-1;
    nnl_MT = [eind_MT(1),diff(eind_MT)];
    nnind_MT = [0,eind_MT];
end

% --------------------- initialize the output --------------------------- %
Sol_MT = zeros(p,T_num,npar);
if npar == 0
    % Nothing to solve; the original code failed on Lambdav(1) here.
    return;
end
ind_zf_MT = false(p*T_num,d_MT,npar);
tsolver_MT = zeros(1,npar); % solver time per lambda (recorded, not returned)
tscreen_MT = zeros(1,npar); % screening time per lambda (recorded, not returned)

% ------------------- compute the effective region of lambda ------------ %
Xsys_MT = Xsys_MT_cal(Xs, ys);
lambda_max=findLambdaMax(Xsys_MT, p*T_num, ind_MT, size(ind_MT,2));
if opts.rFlag == 1
    Lambda = Lambda * lambda_max;
    opts.rFlag = 0;
end
[Lambdav,Lambda_ind] = sort(Lambda,'descend');

% --------- compute the norm of each feature and each submatrix ---------
% Entry (f-1)*T_num + t of the stacked vectors belongs to feature f of task t.
Xnorm_MT = zeros(p*T_num,1);
idx_row = (0:p-1)*T_num+1;
for idx_t = 1:T_num
    Xnorm_MT(idx_row) = (sqrt(sum(Xs{idx_t}.^2,1)))';
    idx_row = idx_row +1;
end
ng_MT = size(ind_MT,2);
Xgnorm_MT = zeros(1,ng_MT);
gind_MT = zeros(d_MT,p*T_num);
if ind_MT(1,1)==-1
    j = 2;
    l = 2;
else
    l = 1;
    j = 1;
end
k = 1;
for i = j : ng_MT-1
    % Columns of the per-task data matrices covered by node i.
    node_cols = ((ind_MT(1,i)-1)/T_num+1):(ind_MT(2,i)/T_num);
    % Largest matrix norm of the node's submatrix over all tasks.
    gnorm = norm(Xs{1}(:,node_cols));
    for idx_t = 2:T_num
        gnorm = max(gnorm,norm(Xs{idx_t}(:,node_cols)));
    end
    Xgnorm_MT(i) = gnorm;
    gind_MT(l,ind_MT(1,i):ind_MT(2,i)) = k;
    k = k + 1;
    if ind_MT(2,i)==p*T_num
        k = 1;
        l = l + 1;
    end
end

% ------- construct sparse matrix to vectorize the computation ----------
% Gind_MT{1,i}(g,m) == 1 when stacked entry m belongs to node g of layer i.
Gind_MT = cell(1,d_MT);
if ind_MT(1,1) == -1
    j = 2;
else
    j = 1;
end
for i = j:d_MT
    Gind_MT{1,i} = sparse(gind_MT(i,:),1:p*T_num,ones(1,p*T_num),nnl_MT(i),p*T_num);
end

% --------------- put Xgnorm and weights in tree structure ---------------
XgnormTree_MT = zeros(p*T_num,d_MT);
weightTree_MT = zeros(p*T_num,d_MT);
for i = 1:d_MT
    if ind_MT(1,1)==-1&&i==1
        XgnormTree_MT(:,i) = Xnorm_MT;
        weightTree_MT(:,i) = ind_MT(3,1);
    else
        G_MT = Gind_MT{1,i};
        XgnormTree_MT(:,i) = G_MT'*(Xgnorm_MT(nnind_MT(i)+1:nnind_MT(i+1)))';
        weightTree_MT(:,i) = G_MT'*(ind_MT(3,nnind_MT(i)+1:nnind_MT(i+1)))';
    end
end

% ----------- solve STM sequentially via HFS ------------------
opts.rFlag = 0; % the input parameters are their true values

s_MT = zeros(p*T_num,1);
c2_MT = zeros(p*T_num,1);
minn_MT = zeros(p*T_num,1);

% Stacked responses and the offsets of each task inside the stacked vector.
% Using per-task offsets (instead of length(ys{1}) for every task) keeps the
% slicing consistent with how the stacked vectors are concatenated.
y_all = vertcat(ys{:});
n_samples = cellfun(@numel, ys);
y_off = [0; cumsum(n_samples(:))];

Xrs = cell(1,T_num);

rLambdav = 1./Lambdav;
lambdap = Lambdav(1);
rlambdap = rLambdav(1);
vnormTree_MT = zeros(p*T_num,d_MT);
tol0 = 1e-12;
for i = 1:npar

    lambdac = Lambdav(1,i);
    rlambdac = rLambdav(1,i);
    if lambdac>=lambda_max
        % Every node is flagged; Sol_MT keeps its zero initialisation.
        ind_zf_MT(:,:,Lambda_ind(i)) = true;
    else
        starts_screening =tic;

        % Use lambda_max (where the solution is zero) as the reference point
        % whenever no previous non-trivial solution exists: on the first
        % iteration, or when the previous lambda was >= lambda_max. The
        % original test "lambdap==lambda_max" indexed Lambda_ind(0) when the
        % largest lambda was below lambda_max and produced a zero normal
        % vector (NaN after normalisation) when it was above lambda_max.
        % When lambdap == lambda_max the values below equal the original ones.
        use_max_ref = (i == 1) || (lambdap >= lambda_max);
        if use_max_ref
            rlambda_ref = 1/lambda_max;
            theta_MT = y_all*rlambda_ref;

            z_MT = Xsys_MT*rlambda_ref;
            [~, v_MT] = Hierarchical_Projection( z_MT, ind_MT, nnind_MT, Gind_MT );

            if ind_MT(3,end)==0
                weightd_MT = (ind_MT(3,nnind_MT(d_MT)+1:nnind_MT(d_MT+1)))';
                [~,Xmxind_MT] = min(abs(Gind_MT{1,d_MT}*(v_MT(:,d_MT).*v_MT(:,d_MT))-weightd_MT.*weightd_MT));
                idx_column_MT  = [ind_MT(1,nnind_MT(d_MT)+Xmxind_MT):ind_MT(2,nnind_MT(d_MT)+Xmxind_MT)];
                nv_MT = [];
                idx_column_X = [(idx_column_MT(end)/T_num)-(length(idx_column_MT)/T_num)+1:(idx_column_MT(end)/T_num)];
                for idx_t = 1:length(ys)
                    idx_column = idx_column_MT(idx_t:T_num:end);
                    nv_MT = [nv_MT; Xs{idx_t}(:,idx_column_X)*v_MT(idx_column,d_MT)];
                end
            else
                nv_MT = [];
                for idx_t = 1:length(ys)
                    % Rows of v_MT belonging to task idx_t. The original
                    % called the undefined function "lenght" here.
                    idx_column = idx_t:T_num:size(v_MT,1);
                    nv_MT = [nv_MT; Xs{idx_t}*v_MT(idx_column,end)];
                end
            end
        else
            % Reference point built from the solution at the previous lambda.
            theta_MT = [];
            for idx_t = 1:length(ys)
                theta_MT = [theta_MT;(ys{idx_t} - Xs{idx_t}*Sol_MT(:,idx_t,Lambda_ind(i-1)))*rlambdap];
            end
            nv_MT = y_all*rlambdap-theta_MT;
        end

        % ----- estimate the possible region of the dual optimum at lambdac
        % Ball with centre o_MT and radius r_MT.
        nv_MT = nv_MT/norm(nv_MT);
        rv_MT = y_all*rlambdac-theta_MT;

        Prv_MT = rv_MT - (nv_MT'*rv_MT)*nv_MT;
        o_MT = theta_MT + 0.5*Prv_MT;
        r_MT = 0.5*norm(Prv_MT);

        % ----- screening by MLFre, remove the ith feature if T(i)=1 ---- %
        c_MT = zeros(p*T_num,1);
        idx_row = (0:p-1)*T_num+1;
        for idx_t = 1:length(ys)
            c_MT(idx_row) = Xs{idx_t}'*o_MT(y_off(idx_t)+1:y_off(idx_t+1));
            idx_row = idx_row+1;
        end
        [~, v_MT] = Hierarchical_Projection( c_MT, ind_MT, nnind_MT, Gind_MT );
        v2_MT = v_MT.*v_MT;
        for l = 1:d_MT % compute norm of v for each node and arrange them based on the tree structure
            if l==1&&ind_MT(1,1)==-1
                vnormTree_MT(:,l) = abs(v_MT(:,l));
            else
                G_MT = Gind_MT{1,l};
                vnormTree_MT(:,l)=G_MT'*sqrt(G_MT*v2_MT(:,l));
            end
        end
        csDifwv_MT = cumsum(weightTree_MT-vnormTree_MT);

        T_MT = false(p*T_num,1); % identify non-leaf inactive nodes
        for l = d_MT:-1:2
            Tl_MT = ~T_MT; % find the indices of the remaining features
            % case 1
            Tc_MT = false(p*T_num,1);
            Tc_MT(Tl_MT) = vnormTree_MT(Tl_MT,l)>tol0;
            if nnz(Tc_MT)>0
                s_MT(Tc_MT) = vnormTree_MT(Tc_MT,l)+r_MT*XgnormTree_MT(Tc_MT,l);
            end

            % case 2 & 3
            if nnz(Tc_MT)<nnz(Tl_MT) % if not all remaining nodes in level l fall in case 1
                Tcc_MT = false(p*T_num,1);
                Tcc_MT(Tl_MT) = ~Tc_MT(Tl_MT);
                lind_MT = nnind_MT(l)+1:nnind_MT(l+1);
                G_MT = Gind_MT{1,l};
                % Nodes of level l whose entries all fall in case 2 & 3.
                Tn_MT = G_MT*Tcc_MT==(ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
                indl_MT = ind_MT(:,lind_MT);
                indlr_MT = indl_MT(:,Tn_MT);
                for n = 1:nnz(Tn_MT)
                    minn_MT(n)=min(csDifwv_MT(indlr_MT(1,n):indlr_MT(2,n),l-1));
                end
                sdist_MT = (G_MT(Tn_MT,:))'*minn_MT(1:nnz(Tn_MT));
                s_MT(Tcc_MT) = max(0,r_MT*XgnormTree_MT(Tcc_MT,l)-sdist_MT(Tcc_MT));
            end

            ind_zf_MT(Tl_MT,l,Lambda_ind(i))=s_MT(Tl_MT)<weightTree_MT(Tl_MT,l);
            T_MT = T_MT|ind_zf_MT(:,l,Lambda_ind(i));
        end

        Tl_MT = ~T_MT; % identify inactive leaf nodes
        if ind_MT(1,1)==-1
            s_MT(Tl_MT) = abs(c_MT(Tl_MT))+r_MT*Xnorm_MT(Tl_MT);
            ind_zf_MT(Tl_MT,1,Lambda_ind(i))=s_MT(Tl_MT)<ind_MT(3,1);
        else
            G_MT = Gind_MT{1,1};
            lind_MT = nnind_MT(1)+1:nnind_MT(2);
            Tn_MT = G_MT*Tl_MT == (ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
            c2_MT(Tl_MT) = c_MT(Tl_MT).*c_MT(Tl_MT);
            cnorm_MT = (G_MT(Tn_MT,:))'*sqrt(G_MT(Tn_MT,:)*c2_MT);
            s_MT(Tl_MT) = cnorm_MT(Tl_MT)+r_MT*XgnormTree_MT(Tl_MT,1);
            ind_zf_MT(Tl_MT,1,Lambda_ind(i))=s_MT(Tl_MT)<weightTree_MT(Tl_MT,1);
        end
        T_MT = T_MT|ind_zf_MT(:,1,Lambda_ind(i));

        nT_MT = ~T_MT; % stacked entries that survive the screening

        % Reduced data matrices. The per-feature mask is read from the
        % entries of the first task only.
        % NOTE(review): this assumes all T_num entries of a feature are
        % kept or discarded together -- confirm against TreeTransform.
        nT_MT_X = nT_MT(1:T_num:end);
        for idx_t = 1:T_num
            Xrs{idx_t} = Xs{idx_t}(:,nT_MT_X);
        end

        % Warm start for the solver, stacked in the same task-interleaved
        % order as the other *_MT vectors.
        if use_max_ref
            opts.x0 = zeros(nnz(nT_MT_X)*T_num,1);
        else
            x0_Matrix = Sol_MT(nT_MT_X,:,Lambda_ind(i-1));
            x0_temp =zeros(size(x0_Matrix,1)*T_num,1);
            idx_row = (0:size(x0_Matrix,1)-1)*T_num+1;
            for idx_t = 1:T_num
                x0_temp(idx_row) = x0_Matrix(:,idx_t);
                idx_row = idx_row+1;
            end
            opts.x0 = x0_temp;
        end

        % ------------ construct the reduced tree ---------------
        Tind_MT = false(ng_MT,1); % nodes whose entries were all discarded
        nnlr_MT = zeros(1,d_MT+1);
        nnlr_MT(end) = 1;
        if ind_MT(1,1)==-1
            j=2;
            nnlr_MT(1)=nnz(nT_MT);
        else
            j=1;
        end
        for l = j:d_MT
            lind_MT = nnind_MT(l)+1:nnind_MT(l+1);
            Tind_MT(lind_MT) = Gind_MT{1,l}*T_MT==(ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
            nnlr_MT(l)=nnz(~Tind_MT(lind_MT));
        end
        if ind_MT(1,1)==-1
            nnindr_MT=[0,1,cumsum(nnlr_MT(2:end))+1];
        else
            nnindr_MT=[0,cumsum(nnlr_MT)];
        end
        indr_MT = ind_MT(:,~Tind_MT);
        % Map original start/end indices to positions among kept entries.
        mapinde_MT = cumsum(nT_MT);
        mapinds_MT = nnz(nT_MT)+1-cumsum(nT_MT,'reverse');
        for l=j:d_MT+1
            lind_MT = nnindr_MT(l)+1:nnindr_MT(l+1);
            oind1_MT = indr_MT(1,lind_MT);
            oind2_MT = indr_MT(2,lind_MT);
            indr_MT(1,lind_MT) = mapinds_MT(oind1_MT);
            indr_MT(2,lind_MT) = mapinde_MT(oind2_MT);
        end

        opts.ind_MT = indr_MT;
        % --- solve the STM problem on the reduced data matrix -- %
        if opts.tFlag == 2
            opts.tol = funVal(Lambda_ind(i));
        end
        tscreen_MT(Lambda_ind(i)) = toc(starts_screening);

        starts = tic;
        [x1, ~, ~]= tree_LeastR_MT(Xrs, ys, lambdac, opts);
        tsolver_MT(Lambda_ind(i)) = toc(starts);

        % Scatter the task-interleaved solution back into the kept rows.
        idx_row= 1:T_num:length(x1);
        for idx_t = 1:T_num
            Sol_MT(nT_MT_X,idx_t,Lambda_ind(i)) = x1(idx_row);
            idx_row = idx_row +1;
        end
    end
    lambdap = lambdac;
    rlambdap = rlambdac;
end

end
354