function [Sol_MT] = STM_HFS(Xs, ys, Lambda, opts)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Implementation of the sequential HFS rule for STM
%
% For every value in Lambda (processed from largest to smallest) the
% function screens out features, builds the reduced data matrices and the
% reduced tree, and calls tree_LeastR_MT on the reduced problem.
%
%% input:
%   Xs:
%       Xs{i} stores the data matrix of the i-th task; each column
%       corresponds to a feature, each row to a data instance.
%       All tasks must have the same number of columns as Xs{1}.
%
%   ys:
%       ys{i} stores the response (column) vector of the i-th task
%
%   Lambda:
%       the parameter values of lambda (row or column vector)
%
%   opts:
%       settings for the solver. Fields read here:
%         opts.ind    - tree structure, passed to TreeTransform
%         opts.rFlag  - if 1, Lambda holds ratios that are multiplied by
%                       lambda_max
%         opts.tFlag  - if 2, opts.funVal(k) is used as opts.tol for the
%                       k-th value in Lambda
%       Fields overwritten before calling the solver:
%         opts.init, opts.rFlag, opts.x0, opts.ind_MT, opts.tol
%
%% output:
%   Sol_MT:
%       the solution; Sol_MT(:,:,i) is the p-by-T_num solution for the
%       i-th value in Lambda (column t belongs to task t)
%
%% For any problem, please contact Weizhong Zhang (zhangweizhongzju@gmail.com)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% -------------------------- pass parameters ---------------------------- %
p      = size(Xs{1}, 2);     % number of features per task
T_num  = length(Xs);         % number of tasks
npar   = length(Lambda);     % number of parameter values of lambda
pT     = p*T_num;            % length of the interleaved coefficient vector
ind    = opts.ind;
ind_MT = TreeTransform(ind, T_num);
% NOTE(review): the original comment here read "starting from a zero
% point", yet opts.x0 is supplied below -- confirm the meaning of
% opts.init against tree_LeastR_MT.
opts.init = 1;
if opts.tFlag == 2
    funVal = opts.funVal;
end
hasLeafRow = (ind_MT(1,1) == -1);   % first column of ind_MT describes the leaf layer

% ------------- recover tree structure for Multi Task ------------------- %
eind_MT = find(ind_MT(2,:) == pT);  % last node of every layer ends at index pT
if hasLeafRow
    d_MT     = length(eind_MT);                     % depth of the tree
    nnl_MT   = [pT, eind_MT(1)-1, diff(eind_MT)];   % number of nodes per layer
    nnind_MT = [0, 1, eind_MT];                     % ind_MT(:,nnind_MT(i)+1:nnind_MT(i+1)) holds layer i
else
    d_MT     = length(eind_MT)-1;
    nnl_MT   = [eind_MT(1), diff(eind_MT)];
    nnind_MT = [0, eind_MT];
end

% --------------------- initialize the output --------------------------- %
Sol_MT     = zeros(p, T_num, npar);
ind_zf_MT  = false(pT, d_MT, npar);   % screening result per feature / layer / lambda
% Timings are collected but not returned by this function.
tsolver_MT = zeros(1, npar);
tscreen_MT = zeros(1, npar);

% ------------------- compute the effective region of lambda ------------ %
Xsys_MT    = Xsys_MT_cal(Xs, ys);
lambda_max = findLambdaMax(Xsys_MT, pT, ind_MT, size(ind_MT,2));
if opts.rFlag == 1
    Lambda = Lambda * lambda_max;
    opts.rFlag = 0;
end
[Lambdav, Lambda_ind] = sort(Lambda, 'descend');

% Stacked responses and the position of every task inside the stack.
% Per-task offsets allow tasks with different numbers of samples.
y_all     = vertcat(ys{:});
n_samples = cellfun(@numel, ys);
y_offset  = [0, cumsum(n_samples(:)')];

% --------- compute the norm of each feature and each submatrix --------- %
% Features are interleaved: entry (f-1)*T_num + t belongs to feature f of task t.
Xnorm_MT = zeros(pT, 1);
idx_row  = (0:p-1)*T_num + 1;
for idx_t = 1:T_num
    Xnorm_MT(idx_row) = (sqrt(sum(Xs{idx_t}.^2, 1)))';
    idx_row = idx_row + 1;
end
ng_MT     = size(ind_MT, 2);
Xgnorm_MT = zeros(1, ng_MT);
gind_MT   = zeros(d_MT, pT);
if hasLeafRow
    j = 2;
    l = 2;
else
    j = 1;
    l = 1;
end
k = 1;
for i = j:ng_MT-1
    % columns of the per-task data matrices covered by node i
    cols = ((ind_MT(1,i)-1)/T_num+1):(ind_MT(2,i)/T_num);
    for idx_t = 1:length(Xs)
        if idx_t == 1
            Xgnorm_MT(i) = norm(Xs{idx_t}(:,cols));
        else
            Xgnorm_MT(i) = max(Xgnorm_MT(i), norm(Xs{idx_t}(:,cols)));
        end
    end
    gind_MT(l, ind_MT(1,i):ind_MT(2,i)) = k;
    k = k + 1;
    if ind_MT(2,i) == pT   % node i closes the current layer
        k = 1;
        l = l + 1;
    end
end

% ------- construct sparse matrix to vectorize the computation ---------- %
% Gind_MT{l}(g,f) = 1 when feature f belongs to node g of layer l.
Gind_MT = cell(1, d_MT);
for i = j:d_MT
    Gind_MT{1,i} = sparse(gind_MT(i,:), 1:pT, ones(1,pT), nnl_MT(i), pT);
end

% --------------- put Xgnorm and weights in tree structure -------------- %
XgnormTree_MT = zeros(pT, d_MT);
weightTree_MT = zeros(pT, d_MT);
for i = 1:d_MT
    if hasLeafRow && i == 1
        XgnormTree_MT(:,i) = Xnorm_MT;
        weightTree_MT(:,i) = ind_MT(3,1);
    else
        G_MT = Gind_MT{1,i};
        layer_nodes = nnind_MT(i)+1:nnind_MT(i+1);
        XgnormTree_MT(:,i) = G_MT'*(Xgnorm_MT(layer_nodes))';
        weightTree_MT(:,i) = G_MT'*(ind_MT(3,layer_nodes))';
    end
end

% ----------- solve STM sequentially via HFS ------------------ %
opts.rFlag = 0; % the input parameters are their true values

s_MT    = zeros(pT, 1);
c2_MT   = zeros(pT, 1);
minn_MT = zeros(pT, 1);
Xrs     = cell(1, T_num);

% The reference point of the sequential rule starts at lambda_max, where
% the solution is zero. Initialising with Lambdav(1) instead would index
% Lambda_ind(0) when Lambdav(1) < lambda_max, and would produce a zero
% normal vector (NaN after normalisation) when Lambdav(1) > lambda_max.
lambdap  = lambda_max;
rlambdap = 1/lambda_max;
vnormTree_MT = zeros(pT, d_MT);
tol0 = 1e-12;
for i = 1:npar
    lambdac  = Lambdav(i);
    rlambdac = 1/lambdac;
    if lambdac >= lambda_max
        % Sol_MT stays zero for this lambda; every feature is marked inactive
        ind_zf_MT(:,:,Lambda_ind(i)) = true;
    else
        starts_screening = tic;
        if lambdap == lambda_max
            theta_MT = y_all*rlambdap;

            z_MT = Xsys_MT*rlambdap;
            [~, v_MT] = Hierarchical_Projection( z_MT, ind_MT, nnind_MT, Gind_MT );

            if ind_MT(3,end) == 0
                top_nodes  = nnind_MT(d_MT)+1:nnind_MT(d_MT+1);
                weightd_MT = (ind_MT(3,top_nodes))';
                % pick the node of layer d_MT whose squared norm of v is
                % closest to its squared weight
                [~, Xmxind_MT] = min(abs(Gind_MT{1,d_MT}*(v_MT(:,d_MT).*v_MT(:,d_MT)) - weightd_MT.*weightd_MT));
                idx_column_MT = ind_MT(1,nnind_MT(d_MT)+Xmxind_MT):ind_MT(2,nnind_MT(d_MT)+Xmxind_MT);
                idx_column_X  = (idx_column_MT(end)/T_num)-(length(idx_column_MT)/T_num)+1:(idx_column_MT(end)/T_num);
                nv_MT = [];
                for idx_t = 1:length(ys)
                    idx_column = idx_column_MT(idx_t:T_num:end);
                    nv_MT = [nv_MT; Xs{idx_t}(:,idx_column_X)*v_MT(idx_column,d_MT)]; %#ok<AGROW>
                end
            else
                nv_MT = [];
                for idx_t = 1:length(ys)
                    % rows of v_MT that belong to task idx_t
                    % (was "lenght(v_MT)", an undefined function)
                    idx_column = idx_t:T_num:size(v_MT,1);
                    nv_MT = [nv_MT; Xs{idx_t}*v_MT(idx_column,end)]; %#ok<AGROW>
                end
            end
        else
            % scaled residual of the previous solution
            theta_MT = [];
            for idx_t = 1:length(ys)
                theta_MT = [theta_MT; (ys{idx_t} - Xs{idx_t}*Sol_MT(:,idx_t,Lambda_ind(i-1)))*rlambdap]; %#ok<AGROW>
            end
            nv_MT = y_all*rlambdap - theta_MT;
        end

        % ----- estimate the possible region of the dual optimum at lambdac
        nv_MT  = nv_MT/norm(nv_MT);
        rv_MT  = y_all*rlambdac - theta_MT;
        Prv_MT = rv_MT - (nv_MT'*rv_MT)*nv_MT;   % component of rv_MT orthogonal to nv_MT
        o_MT   = theta_MT + 0.5*Prv_MT;          % centre of the ball
        r_MT   = 0.5*norm(Prv_MT);               % radius of the ball

        % ----- screening by MLFre, remove the ith feature if T(i)=1 ---- %
        c_MT    = zeros(pT, 1);
        idx_row = (0:p-1)*T_num + 1;
        for idx_t = 1:length(ys)
            c_MT(idx_row) = Xs{idx_t}'*o_MT(y_offset(idx_t)+1:y_offset(idx_t+1));
            idx_row = idx_row + 1;
        end
        [~, v_MT] = Hierarchical_Projection( c_MT, ind_MT, nnind_MT, Gind_MT );
        v2_MT = v_MT.*v_MT;
        for l = 1:d_MT % compute norm of v for each node and arrange them based on the tree structure
            if l == 1 && hasLeafRow
                vnormTree_MT(:,l) = abs(v_MT(:,l));
            else
                G_MT = Gind_MT{1,l};
                vnormTree_MT(:,l) = G_MT'*sqrt(G_MT*v2_MT(:,l));
            end
        end
        % NOTE(review): cumsum runs down the rows (dimension 1) here --
        % confirm this is intended rather than accumulation across layers.
        csDifwv_MT = cumsum(weightTree_MT - vnormTree_MT);

        T_MT = false(pT, 1); % identify non-leaf inactive nodes
        for l = d_MT:-1:2
            Tl_MT = ~T_MT; % find the indices of the remaining features
            % case 1
            Tc_MT = false(pT, 1);
            Tc_MT(Tl_MT) = vnormTree_MT(Tl_MT,l) > tol0;
            if nnz(Tc_MT) > 0
                s_MT(Tc_MT) = vnormTree_MT(Tc_MT,l) + r_MT*XgnormTree_MT(Tc_MT,l);
            end

            % case 2 & 3
            if nnz(Tc_MT) < nnz(Tl_MT) % if not all remaining nodes in level l fall in case 1
                Tcc_MT = false(pT, 1);
                Tcc_MT(Tl_MT) = ~Tc_MT(Tl_MT);
                lind_MT = nnind_MT(l)+1:nnind_MT(l+1);
                G_MT = Gind_MT{1,l};
                % nodes of layer l whose features all fall in case 2 & 3
                Tn_MT = G_MT*Tcc_MT == (ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
                indl_MT  = ind_MT(:,lind_MT);
                indlr_MT = indl_MT(:,Tn_MT);
                nTn = nnz(Tn_MT);
                for n = 1:nTn
                    minn_MT(n) = min(csDifwv_MT(indlr_MT(1,n):indlr_MT(2,n),l-1));
                end
                sdist_MT = (G_MT(Tn_MT,:))'*minn_MT(1:nTn);
                s_MT(Tcc_MT) = max(0, r_MT*XgnormTree_MT(Tcc_MT,l) - sdist_MT(Tcc_MT));
            end

            ind_zf_MT(Tl_MT,l,Lambda_ind(i)) = s_MT(Tl_MT) < weightTree_MT(Tl_MT,l);
            T_MT = T_MT | ind_zf_MT(:,l,Lambda_ind(i));
        end

        Tl_MT = ~T_MT; % identify inactive leaf nodes
        if hasLeafRow
            s_MT(Tl_MT) = abs(c_MT(Tl_MT)) + r_MT*Xnorm_MT(Tl_MT);
            ind_zf_MT(Tl_MT,1,Lambda_ind(i)) = s_MT(Tl_MT) < ind_MT(3,1);
        else
            G_MT = Gind_MT{1,1};
            lind_MT = nnind_MT(1)+1:nnind_MT(2);
            Tn_MT = G_MT*Tl_MT == (ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
            c2_MT(Tl_MT) = c_MT(Tl_MT).*c_MT(Tl_MT);
            cnorm_MT = (G_MT(Tn_MT,:))'*sqrt(G_MT(Tn_MT,:)*c2_MT);
            s_MT(Tl_MT) = cnorm_MT(Tl_MT) + r_MT*XgnormTree_MT(Tl_MT,1);
            ind_zf_MT(Tl_MT,1,Lambda_ind(i)) = s_MT(Tl_MT) < weightTree_MT(Tl_MT,1);
        end
        T_MT = T_MT | ind_zf_MT(:,1,Lambda_ind(i));

        nT_MT = ~T_MT;   % features that survive the screening

        % ------------ build the reduced data matrices ----------
        % The reduced column set is read from the task-1 entry of every feature.
        nT_MT_X = nT_MT(1:T_num:end);
        for idx_t = 1:T_num
            Xrs{idx_t} = Xs{idx_t}(:,nT_MT_X);
        end

        % ------------ starting point for the solver ------------
        if lambdap == lambda_max
            opts.x0 = zeros(nnz(nT_MT_X)*T_num, 1);
        else
            % interleave the previous solution: entry (f-1)*T_num + t
            x0_Matrix = Sol_MT(nT_MT_X,:,Lambda_ind(i-1));
            opts.x0 = reshape(x0_Matrix.', [], 1);
        end

        % ------------ construct the reduced tree ---------------
        Tind_MT = false(ng_MT, 1);
        nnlr_MT = zeros(1, d_MT+1);
        nnlr_MT(end) = 1;
        if hasLeafRow
            j = 2;
            nnlr_MT(1) = nnz(nT_MT);
        else
            j = 1;
        end
        for l = j:d_MT
            lind_MT = nnind_MT(l)+1:nnind_MT(l+1);
            % a node is dropped when all of its features were screened out
            Tind_MT(lind_MT) = Gind_MT{1,l}*T_MT == (ind_MT(2,lind_MT)-ind_MT(1,lind_MT)+1)';
            nnlr_MT(l) = nnz(~Tind_MT(lind_MT));
        end
        if hasLeafRow
            nnindr_MT = [0, 1, cumsum(nnlr_MT(2:end))+1];
        else
            nnindr_MT = [0, cumsum(nnlr_MT)];
        end
        indr_MT = ind_MT(:,~Tind_MT);
        % map original feature indices to positions in the reduced vector
        mapinde_MT = cumsum(nT_MT);
        mapinds_MT = nnz(nT_MT) + 1 - cumsum(nT_MT,'reverse');
        for l = j:d_MT+1
            lind_MT  = nnindr_MT(l)+1:nnindr_MT(l+1);
            oind1_MT = indr_MT(1,lind_MT);
            oind2_MT = indr_MT(2,lind_MT);
            indr_MT(1,lind_MT) = mapinds_MT(oind1_MT);
            indr_MT(2,lind_MT) = mapinde_MT(oind2_MT);
        end

        opts.ind_MT = indr_MT;
        % --- solve the STM problem on the reduced data matrix -- %
        if opts.tFlag == 2
            opts.tol = funVal(Lambda_ind(i));
        end
        tscreen_MT(Lambda_ind(i)) = toc(starts_screening);

        starts = tic;
        [x1, ~, ~] = tree_LeastR_MT(Xrs, ys, lambdac, opts);
        tsolver_MT(Lambda_ind(i)) = toc(starts);

        % scatter the interleaved reduced solution back into Sol_MT
        idx_row = 1:T_num:length(x1);
        for idx_t = 1:T_num
            Sol_MT(nT_MT_X,idx_t,Lambda_ind(i)) = x1(idx_row);
            idx_row = idx_row + 1;
        end
    end
    % For lambdac >= lambda_max the solution is zero, so lambda_max itself
    % is kept as the reference point for the next step.
    lambdap  = min(lambdac, lambda_max);
    rlambdap = 1/lambdap;
end

end