function [Sol_MT] = STM_HFS(Xs, ys, Lambda, opts)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Implementation of the sequential HFS rule for STM
%% input:
% Xs:
% Xs{i} stores the data matrix of the i-th task, each column corresponds to a feature
% each row corresponds to a data instance
%
% ys:
% ys{i} strores the response vector of the i-th task
%
% Lambda:
% the parameter values of lambda
%
% opts:
% settings for the solver
%% output:
% Sol:
% the solution; Sol(:,:,i) stores the the
% solution for the ith values in Lambda
%
%% For any problem, please contact Weizhong Zhang (zhangweizhongzju@gmail.com)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%
% -------------------------- pass parameters ---------------------------- %
p = size(Xs{1}, 2);
T_num = length(Xs); % number of tasks
npar = length(Lambda); % number of parameter values of lambda
ind = opts.ind;
ind_MT = TreeTransform(ind, T_num);
%clear('ind');
opts.init=1; % starting from a zero point
if opts.tFlag==2
funVal = opts.funVal;
end
% --------------------- recover tree structure for Multi Task-------------------------- %
eind_MT = find(ind_MT(2,:) == p*T_num);
if ind_MT(1,1) == -1 % find the depth of the tree
d_MT = length(eind_MT);
nnl_MT = [p*T_num,eind_MT(1)-1,diff(eind_MT)]; % number of nodes per layer
nnind_MT = [0,1,eind_MT]; % ind(1,nnind1(i)+1:nnind(i+1)) stores the node from the ith layer
else
d_MT = length(eind_MT)-1;
nnl_MT = [eind_MT(1),diff(eind_MT)];
nnind_MT = [0,eind_MT];
end
% --------------------- initialize the output --------------------------- %
Sol_MT = zeros(p,T_num,npar);
ind_zf_MT = false(p*T_num,d_MT,npar);
tsolver_MT = zeros(1,npar);%should be changed
tscreen_MT = zeros(1,npar);
% ------------------- compute the effective region of lambda ------------ %
Xsys_MT = Xsys_MT_cal(Xs, ys);
%lambda_max=findLambdaMax(X'*y, p, ind, size(ind,2)); % why?
lambda_max=findLambdaMax(Xsys_MT, p*T_num, ind_MT, size(ind_MT,2));
%lambda_max1=findLambdaMax(Xsys_MT, p*T_num, ind, size(ind,2));
if opts.rFlag == 1
Lambda = Lambda * lambda_max;
opts.rFlag = 0;
end
[Lambdav,Lambda_ind] = sort(Lambda,'descend');
% --------- compute the norm of each feature and each submatrix ---------
Xnorm_MT = zeros(size(Xs{1},2)*T_num,1);
idx_row = [0:size(Xs{1},2)-1]*T_num+1;
for idx_t = 1:T_num
Xnorm_MT(idx_row) = (sqrt(sum(Xs{idx_t}.^2,1)))';
idx_row = idx_row +1;
end
ng_MT = size(ind_MT,2);
Xgnorm_MT = zeros(1,ng_MT);
gind_MT = zeros(d_MT,p*T_num);
if ind_MT(1,1)==-1
j = 2;
l = 2;
else
l = 1;
j = 1;
end
k = 1;
for i = j : ng_MT-1
for idx_t = 1: length(Xs)
if(idx_t ==1)
Xgnorm_MT(i) = norm(Xs{idx_t}(:,((ind_MT(1,i)-1)/T_num+1):(ind_MT(2,i)/T_num)));
else
Xgnorm_MT(i)= max(Xgnorm_MT(i),norm(Xs{idx_t}(:,((ind_MT(1,i)-1)/T_num+1):(ind_MT(2,i)/T_num))));
end
end
gind_MT(l,ind_MT(1,i):ind_MT(2,i)) = k;
k = k + 1;
if ind_MT(2,i)==p*T_num
k = 1;
l = l + 1;
end
end
% ------- construct sparse matrix to vectorize the computation ----------
Gind_MT = cell(1,d_MT);
if ind_MT(1,1) == -1
j = 2;
else
j = 1;
end
for i = j:d_MT
Gind_MT{1,i} = sparse(gind_MT(i,:),1:p*T_num,ones(1,p*T_num),nnl_MT(i),p*T_num);
end
% --------------- put Xgnorm and weights in tree structure ---------------
XgnormTree_MT = zeros(p*T_num,d_MT);
weightTree_MT = zeros(p*T_num,d_MT);
for i = 1:d_MT
if ind_MT(1,1)==-1&&i==1
XgnormTree_MT(:,i) = Xnorm_MT;
weightTree_MT(:,i) = ind_MT(3,1);
else
G_MT = Gind_MT{1,i};
XgnormTree_MT(:,i) = G_MT'*(Xgnorm_MT(nnind_MT(i)+1:nnind_MT(i+1)))';
weightTree_MT(:,i) = G_MT'*(ind_MT(3,nnind_MT(i)+1:nnind_MT(i+1)))';
end
end
% ----------- solve STM sequentially via HFS ------------------
opts.rFlag = 0; % the input parameters are their true values
s_MT = zeros(p*T_num,1);
c2_MT = zeros(p*T_num,1);
minn_MT = zeros(p*T_num,1);
rLambdav = 1./Lambdav;
lambdap = Lambdav(1);
rlambdap = rLambdav(1);
vnormTree_MT = zeros(p*T_num,d_MT);
tol0 = 1e-12;
for i = 1:npar
%fprintf('in HFS step: %d\n',i);
lambdac = Lambdav(1,i);
rlambdac = rLambdav(1,i);
if lambdac>=lambda_max
ind_zf_MT(:,:,Lambda_ind(i)) = true;
else
starts_screening =tic;
if lambdap==lambda_max
theta_MT = [];
for idx_t = 1: length(ys)
theta_MT = [theta_MT; ys{idx_t}*rlambdap];
end
z_MT = Xsys_MT*rlambdap;
[u_MT, v_MT] = Hierarchical_Projection( z_MT, ind_MT, nnind_MT, Gind_MT );
if ind_MT(3,end)==0
weightd_MT = (ind_MT(3,nnind_MT(d_MT)+1:nnind_MT(d_MT+1)))';
[~,Xmxind_MT] = min(abs(Gind_MT{1,d_MT}*(v_MT(:,d_MT).*v_MT(:,d_MT))-weightd_MT.*weightd_MT));
idx_column_MT = [ind_MT(1,nnind_MT(d_MT)+Xmxind_MT):ind_MT(2,nnind_MT(d_MT)+Xmxind_MT)];
nv_MT = [];
idx_column_X = [(idx_column_MT(end)/T_num)-(length(idx_column_MT)/T_num)+1:(idx_column_MT(end)/T_num)];
for idx_t = 1:length(ys)
idx_column = idx_column_MT(idx_t:T_num:end);
nv_MT = [nv_MT; Xs{idx_t}(:,idx_column_X)*v_MT(idx_column,d_MT)];
end
else
nv_MT = [];
for idx_t = 1:length(ys)
idx_column = [idx_t:T_num:lenght(v_MT)];
nv_MT = [nv_MT; Xs{idx_t}*v_MT(idx_column,end)];
end
end
else
theta_MT = [];
y_all = [];
for idx_t = 1:length(ys)
theta_MT = [theta_MT;(ys{idx_t} - Xs{idx_t}*Sol_MT(:,idx_t,Lambda_ind(i-1)))*rlambdap];
y_all = [y_all; ys{idx_t}];
end
nv_MT = y_all*rlambdap-theta_MT;
end
% ----- estimate the possible region of the dual optimum at lambdac
nv_MT = nv_MT/norm(nv_MT);
%rv = y*rlambdac-theta;
rv_MT = [];
for idx_t = 1:length(ys)
rv_MT = [rv_MT;ys{idx_t}*rlambdac];
end
rv_MT = rv_MT-theta_MT;
Prv_MT = rv_MT - (nv_MT'*rv_MT)*nv_MT;
o_MT = theta_MT + 0.5*Prv_MT;
r_MT = 0.5*norm(Prv_MT);
% ----- screening by MLFre, remove the ith feature if T(i)=1 ---- %
c_MT = zeros(p*T_num,1);
idx_row = (0:size(Xs{1},2)-1)*length(Xs)+1;
for idx_t = 1:length(ys)
c_MT(idx_row) = Xs{idx_t}'*o_MT((idx_t-1)*length(ys{1})+1:idx_t*length(ys{1}));
idx_row = idx_row+1;
end
[u_MT, v_MT] = Hierarchical_Projection( c_MT, ind_MT, nnind_MT, Gind_MT );
v2_MT = v_MT.*v_MT;
for l = 1:d_MT % compute norm of v for each node and arrange them based on the tree structure
if l==1&&ind_MT(1,1)==-1
vnormTree_MT(:,l) = abs(v_MT(:,l));
else
G_MT = Gind_MT{1,l};
vnormTree_MT(:,l)=G_MT'*sqrt(G_MT*v2_MT(:,l));
end
end
csDifwv_MT = cumsum(weightTree_MT-vnormTree_MT);
T_MT = false(p*T_num,1); % identify non-leaf inactive nodes
for l = d_MT:-1:2
Tl_MT = ~T_MT; % find the indices of the remaining features
% case 1
Tc_MT = false(p*T_num,1);
Tc_MT(Tl_MT) = vnormTree_MT(Tl_MT,l)>tol0;
if nnz(Tc_MT)>0
s_MT(Tc_MT) = vnormTree_MT(Tc_MT,l)+r_MT*XgnormTree_MT(Tc_MT,l);
end
% case 2 & 3
if nnz(Tc_MT)<nnz(Tl_MT) % if not all remaining nodes in level l fall in case 1
Tcc_MT = false(p*T_num,1);
Tcc_MT(Tl_MT) = ~Tc_MT(Tl_MT);
lind_MT = nnind_MT(l)+1:nnind_MT(l+1);
G_MT = Gind_MT{1,l};
Tn_MT = G_MT*Tcc_MT==(ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
indl_MT = ind_MT(:,lind_MT);
indlr_MT = indl_MT(:,Tn_MT);
for n = 1:nnz(Tn_MT)
minn_MT(n)=min(csDifwv_MT(indlr_MT(1,n):indlr_MT(2,n),l-1));
end
sdist_MT = (G_MT(Tn_MT,:))'*minn_MT(1:nnz(Tn_MT));
s_MT(Tcc_MT) = max(0,r_MT*XgnormTree_MT(Tcc_MT,l)-sdist_MT(Tcc_MT));
end
ind_zf_MT(Tl_MT,l,Lambda_ind(i))=s_MT(Tl_MT)<weightTree_MT(Tl_MT,l);
T_MT = T_MT|ind_zf_MT(:,l,Lambda_ind(i));
end
Tl_MT = ~T_MT; % identify inactive leaf nodes
if ind_MT(1,1)==-1
s_MT(Tl_MT) = abs(c_MT(Tl_MT))+r_MT*Xnorm_MT(Tl_MT);
ind_zf_MT(Tl_MT,1,Lambda_ind(i))=s_MT(Tl_MT)<ind_MT(3,1);
else
G_MT = Gind_MT{1,1};
lind_MT = nnind_MT(1)+1:nnind_MT(2);
Tn_MT = G_MT*Tl_MT == (ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
c2_MT(Tl_MT) = c_MT(Tl_MT).*c_MT(Tl_MT);
cnorm_MT = (G_MT(Tn_MT,:))'*sqrt(G_MT(Tn_MT,:)*c2_MT);
s_MT(Tl_MT) = cnorm_MT(Tl_MT)+r_MT*XgnormTree_MT(Tl_MT,1);
ind_zf_MT(Tl_MT,1,Lambda_ind(i))=s_MT(Tl_MT)<weightTree_MT(Tl_MT,1);
end
T_MT = T_MT|ind_zf_MT(:,1,Lambda_ind(i));
nT_MT = ~T_MT;
%Xr = X(:,nT);
nT_MT_X = nT_MT(1:T_num:end);
for idx_t = 1:T_num
Xrs{idx_t} = Xs{idx_t}(:,nT_MT_X);
end
if lambdap == lambda_max
opts.x0 = zeros(nnz(nT_MT_X)*T_num,1);
else
x0_Matrix = Sol_MT(nT_MT_X,:,Lambda_ind(i-1));
x0_temp =zeros(size(x0_Matrix,1)*T_num,1);
idx_row = [0:size(x0_Matrix,1)-1]*T_num+1;
for idx_t = 1:T_num
x0_temp(idx_row) = x0_Matrix(:,idx_t);
idx_row = idx_row+1;
end
opts.x0 = x0_temp;
end
% ------------ construct the reduced tree ---------------
Tind_MT = false(ng_MT,1);
nnlr_MT = zeros(1,d_MT+1);
nnlr_MT(end) = 1;
if ind_MT(1,1)==-1
j=2;
nnlr_MT(1)=nnz(nT_MT);
else
j=1;
end
for l = j:d_MT
lind_MT = nnind_MT(l)+1:nnind_MT(l+1);
Tind_MT(lind_MT) = Gind_MT{1,l}*T_MT==(ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
nnlr_MT(l)=nnz(~Tind_MT(lind_MT));
end
if ind_MT(1,1)==-1
nnindr_MT=[0,1,cumsum(nnlr_MT(2:end))+1];
else
nnindr_MT=[0,cumsum(nnlr_MT)];
end
indr_MT = ind_MT(:,~Tind_MT);
mapinde_MT = cumsum(nT_MT);
mapinds_MT = nnz(nT_MT)+1-cumsum(nT_MT,'reverse');
for l=j:d_MT+1
lind_MT = nnindr_MT(l)+1:nnindr_MT(l+1);
oind1_MT = indr_MT(1,lind_MT);
oind2_MT = indr_MT(2,lind_MT);
indr_MT(1,lind_MT) = mapinds_MT(oind1_MT);
indr_MT(2,lind_MT) = mapinde_MT(oind2_MT);
end
opts.ind_MT = indr_MT;
% --- solve the STM problem on the reduced data matrix -- %
if opts.tFlag == 2
opts.tol = funVal(Lambda_ind(i));
end
tscreen_MT(Lambda_ind(i)) = toc(starts_screening);
starts = tic;
[x1, ~, ~]= tree_LeastR_MT(Xrs, ys, lambdac, opts);
tsolver_MT(Lambda_ind(i)) = toc(starts);
nT_MT_temp = nT_MT(1:T_num:end);
idx_row= [1:T_num:length(x1)];
for idx_t = 1:T_num
Sol_MT(nT_MT_temp,idx_t,Lambda_ind(i)) = x1(idx_row);
idx_row = idx_row +1;
end
end
lambdap = lambdac;
rlambdap = rlambdac;
end
end