Diff of /Demo/MTFL.m [000000] .. [edb3de]

Switch to unified view

a b/Demo/MTFL.m
1
function [W,Omega]=MTERL(data,label,lambda,gamma,nu)
2
%the Omega sloved in the demo can  refer to: 
3
%Yu Zhang and Dit-Yan Yeung. A Convex Formulation for Learning Task Relationships in Multi-Task Learning. 
4
%In: Proceedings of the 26th Conference on Uncertainty in Artificial Intelligence (UAI), 2010.
5
6
m=length(label); 
7
d=size(data{1,1},2);
8
W=zeros(d,m); 
9
V=zeros(d,m);
10
11
[W,Omega]=myAPG(W,V,data,label,lambda,gamma,nu);
12
13
end
14
function [W,Omega]=myAPG(W_initial,V,data,label,lambda,gamma,nu)
15
m=length(label);
16
d=size(data{1,1},2);
17
epsilon=10^(-8);
18
max_iteration=1000;
19
t=0;  alpha = 1; W=W_initial;
20
Omega=eye(m)/m;
21
[data_O,label_O,task_index,ins_num]=PreprocessMTData(data,label);
22
23
n=size(data_O,1);
24
insIndex=cell(1,m);
25
ins_indicator=zeros(m,n);
26
for i=1:m
27
    insIndex{i}=sort(find(task_index==i));
28
    ins_indicator(i,insIndex{i})=1;
29
end
30
threshold=10^(-12);
31
model.alpha=zeros(1,n);
32
model.b=zeros(1,m);kertype='linear';kerpar=0;
33
Km=CalculateKernelMatrix(data_O,kertype,kerpar);
34
m_Cor=real(Omega/(lambda*Omega+gamma*eye(m)));
35
36
for iter=1:max_iteration
37
    old_model=model;
38
    old_Omega=Omega;
39
    MTKm=Km.*m_Cor(task_index,task_index);
40
    model=MTRL_RR(MTKm,label_O,task_index,insIndex,ins_num);
41
    clear MTKm;
42
    temp=m_Cor(:,task_index)*diag(model.alpha);
43
    temp=temp*Km*temp';
44
    [eigVector,eigValue]=eig(temp+epsilon*eye(m));
45
    clear temp;
46
    eigValue=sqrt(abs(diag(eigValue)));
47
    eigValue=eigValue/sum(eigValue);
48
    Omega=eigVector*diag(eigValue)*eigVector';
49
    m_Cor=eigVector*diag(eigValue./(lambda*eigValue+gamma))*eigVector';
50
    clear eigVector eigValue;
51
    if norm(model.alpha-old_model.alpha,2)<=threshold*n&&norm(model.b-old_model.b,2)<=threshold*m&&norm(Omega-old_Omega,'fro')<=threshold*m*m
52
        clear old_model old_Omega;
53
        break;
54
    end
55
    clear old_model old_Omega; 
56
    U=(1-alpha)*W+alpha*V;
57
    Lu=100000;%should be set
58
    G=[];
59
    for i=1:d
60
        Vi=min(1,max(-1,U(i,:)/nu));
61
        G=[G;sum(abs(U(i,:)))*Vi];       
62
    end
63
    tmp_o = U*Omega^(-1);
64
    for i=1:m
65
        V(:,i)=V(:,i)-1/(alpha*Lu)*( data{1,i}'* (data{1,i}* U(:,i)-label{1,i}') + lambda*G(:,i) +gamma* tmp_o(:,i));
66
    end  
67
    W=(1-alpha)*W+alpha*V;
68
    alpha=2/(t+1);
69
    t=t+1;    
70
    F_wh=0;  
71
    for i=1:m
72
        F_wh = F_wh+norm(label{1,i}-W(:,i)'*data{1,i}',2)^2/(2*m);
73
    end
74
    for i=1:d      
75
        F_wh =F_wh+lambda*sum( abs(W(i,:)) )^2/2;
76
    end
77
    F_wh = F_wh + gamma * trace(W*Omega^(-1)*W')/2;
78
    fValue(iter,1)=F_wh;    
79
    if (iter>10 && (abs(fValue(iter,1)-fValue(iter-1,1))/abs(fValue(iter-1,1))<epsilon))   
80
        break;
81
    end   
82
end
83
end