% Demo_RR/maLRR.m
function [Z,Ez,Ew,W,Wi] = maLRR(T,S,Dim,Maxiter,alaph,beta)
% maLRR -- Multi-Center Adaptation Framework with Low-Rank Representation
%
%   [Z,Ez,Ew,W,Wi] = maLRR(T,S,Dim,Maxiter,alaph,beta)
%
% Runs an inexact augmented-Lagrange-multiplier (ALM) iteration whose
% updates correspond to
%
%   min  sum_i ||Z_i||_* + ||W||_* + alaph*sum_i ||Ez_i||_1 + beta*sum_i ||Ew_i||_1
%   s.t. Wi_i*S_i = W*T*Z_i + Ez_i,   Wi_i = W + Ew_i      (i = 1..M)
%
% The nuclear norms are carried by auxiliary copies F_i = Z_i and J = W.
% After each W update, W is re-orthonormalized (see malrr_orthonormal).
% The loop stops after Maxiter iterations, or earlier when every constraint
% residual is below tol in absolute value.
%
% Input:
%   T       : d x m target-domain data; T*Z_i combines its columns
%   S       : cell array of M source-domain data sets; S{i} is d x n_i
%   Dim     : number of rows of the projections (W and Wi{i} are Dim x d)
%   Maxiter : maximum number of iterations
%   alaph   : weight of the l1 penalty on Ez (soft threshold alaph/mu)
%   beta    : weight of the l1 penalty on Ew (soft threshold beta/mu)
%
% Output:
%   Z  : 1 x M cell; Z{i} (m x n_i) is the low-rank representation of
%        source i in the target domain
%   Ez : 1 x M cell; Ez{i} (Dim x n_i) is the representation error
%   Ew : 1 x M cell; Ew{i} (Dim x d) is the deviation of Wi{i} from W
%   W  : Dim x d projection matrix shared by all sources
%   Wi : 1 x M cell; Wi{i} (Dim x d) is the projection matrix of source i
%
% The SVDs run on the GPU (gpuArray) when that works, and on the CPU otherwise.
% Progress is printed at iteration 1, every 50 iterations, and on convergence.

% The parameters need to be tuned.
mu = 1e-5; tol = 1e-8; max_mu = 1e6; rho = 1.2;

[d,m] = size(T);
M = numel(S);
for i = 1:M
    if size(S{i},1) ~= d
        error('maLRR:dimMismatch', ...
            'S{%d} has %d rows but T has %d; all domains must share the feature dimension.', ...
            i, size(S{i},1), d);
    end
end

% Primal variables and Lagrange multipliers. Y1..Y4 pair with the four
% constraint residuals computed at the end of each iteration.
Z = cell(1,M); F = cell(1,M); Wi = cell(1,M); Ez = cell(1,M); Ew = cell(1,M);
Y1 = cell(1,M); Y2 = cell(1,M); Y4 = cell(1,M);
Y1_tmp = cell(1,M); Y2_tmp = cell(1,M); Y4_tmp = cell(1,M);
SSt = cell(1,M);
W = eye(Dim,d); Y3 = zeros(Dim,d);
for i = 1:M
    n = size(S{i},2);
    Z{i} = zeros(m,n);   F{i} = zeros(m,n);   Y4{i} = zeros(m,n);
    Ez{i} = zeros(Dim,n); Y1{i} = zeros(Dim,n);
    Wi{i} = zeros(Dim,d); Ew{i} = zeros(Dim,d); Y2{i} = zeros(Dim,d);
    % S_i*S_i' + I does not change across iterations, so build it only once.
    SSt{i} = S{i}*S{i}' + eye(d);
end

iter = 0;
while iter < Maxiter
    iter = iter + 1;

    % F_i = argmin ||F||_* + mu/2*||Z_i - F + Y4_i/mu||^2
    for i = 1:M
        F{i} = malrr_svt(Z{i} + Y4{i}/mu, 1/mu);
    end

    % Wi_i: closed-form solution of its quadratic subproblem.
    WT = W*T;
    for i = 1:M
        rhs = (WT*Z{i} + Ez{i})*S{i}' + W + Ew{i} - (Y1{i}*S{i}' + Y2{i})/mu;
        Wi{i} = rhs / SSt{i};
    end

    % Z_i: closed-form solution; the system matrix depends only on W.
    TW = T'*W';
    Zlhs = T'*(W'*W)*T + eye(m);
    for i = 1:M
        rhs = (TW*Y1{i} - Y4{i})/mu + F{i} + TW*(Wi{i}*S{i} - Ez{i});
        Z{i} = Zlhs \ rhs;
    end

    % J = argmin ||J||_* + mu/2*||W - J + Y3/mu||^2
    J = malrr_svt(W + Y3/mu, 1/mu);

    % Ez_i and Ew_i: elementwise soft thresholding (l1 proximal steps).
    % W has not been updated yet, so WT is still current here.
    for i = 1:M
        Ez{i} = malrr_soft(Wi{i}*S{i} - WT*Z{i} + Y1{i}/mu, alaph/mu);
        Ew{i} = malrr_soft(Wi{i} - W + Y2{i}/mu, beta/mu);
    end

    % W: closed-form solution of its quadratic subproblem, then
    % re-orthonormalization that keeps the Dim x d shape.
    Wlhs = eye(d);
    Wrhs = J - Y3/mu;
    for i = 1:M
        TZ = T*Z{i};
        Wlhs = Wlhs + TZ*TZ' + eye(d);
        Wrhs = Wrhs + (Y1{i}*TZ' + Y2{i})/mu + (Wi{i}*S{i} - Ez{i})*TZ' + Wi{i} - Ew{i};
    end
    W = malrr_orthonormal(Wrhs / Wlhs);

    % Constraint residuals; their largest magnitude is the stopping measure.
    WT = W*T;
    Y3_tmp = W - J;
    stopALM = malrr_maxabs(Y3_tmp);
    for i = 1:M
        Y1_tmp{i} = Wi{i}*S{i} - WT*Z{i} - Ez{i};
        Y2_tmp{i} = Wi{i} - W - Ew{i};
        Y4_tmp{i} = Z{i} - F{i};
        stopALM = max([stopALM, malrr_maxabs(Y1_tmp{i}), ...
            malrr_maxabs(Y2_tmp{i}), malrr_maxabs(Y4_tmp{i})]);
    end

    if iter == 1 || mod(iter,50) == 0 || stopALM < tol
        disp(['iter ' num2str(iter) ',mu=' num2str(mu,'%2.1e') ...
            ',stopALM=' num2str(stopALM,'%2.3e')]);
    end
    if stopALM < tol
        break;
    end

    % Dual ascent on the multipliers, then increase the penalty.
    Y3 = Y3 + mu*Y3_tmp;
    for i = 1:M
        Y1{i} = Y1{i} + mu*Y1_tmp{i};
        Y2{i} = Y2{i} + mu*Y2_tmp{i};
        Y4{i} = Y4{i} + mu*Y4_tmp{i};
    end
    mu = min(max_mu, mu*rho);
end
end

function X = malrr_svt(A, tau)
% Singular value thresholding: subtract tau from the singular values of A,
% keep only those that were larger than tau, and rebuild the matrix.
% Returns zeros(size(A)) when no singular value exceeds tau.
try
    [U,Sig,V] = svd(gpuArray(A),'econ');
    U = gather(U); Sig = gather(Sig); V = gather(V);
catch
    % No usable GPU / Parallel Computing Toolbox: compute on the CPU instead.
    [U,Sig,V] = svd(A,'econ');
end
sigma = diag(Sig);
keep = sum(sigma > tau);   % singular values are sorted in decreasing order
if keep >= 1
    X = U(:,1:keep)*diag(sigma(1:keep) - tau)*V(:,1:keep)';
else
    X = zeros(size(A));
end
end

function X = malrr_soft(A, tau)
% Elementwise soft thresholding: sign(A).*max(abs(A)-tau, 0).
X = max(0, A - tau) + min(0, A + tau);
end

function Q = malrr_orthonormal(W)
% Orthonormal basis with the same size as W. For square or tall W this is
% orth(W), an orthonormal basis of the column space. For wide W it is
% orth(W')', orthonormal rows spanning the row space. Plain orth(W) on a
% wide W would return a square Dim x Dim matrix and break every later
% product with T.
if size(W,1) < size(W,2)
    Q = orth(W')';
else
    Q = orth(W);
end
if ~isequal(size(Q), size(W))
    error('maLRR:rankDeficient', ...
        'The projection W became rank deficient (rank %d); cannot keep it %d x %d.', ...
        min(size(Q)), size(W,1), size(W,2));
end
end

function v = malrr_maxabs(X)
% Largest absolute entry of X; 0 when X is empty.
v = max([0; abs(X(:))]);
end