Diff of /MLFre/gen_sync1.m [000000] .. [d8e26d]

Switch to unified view

a b/MLFre/gen_sync1.m
1
function [ X, y, beta, ind ] = gen_sync1( p, N, d, nns, ratio )
2
%% generate synthetic data with zero pair-wise correlation
3
4
nl = p./nns; % the lenght of nl is d+1; n[end] = 1 as we only has 
5
                      % one root node; nl[1] is the number of nodes at d
6
                      % layer; nl[2] is the number of nodes at d-1 layer
7
8
% construct sparse matrix to vectorize the computation
9
gind = zeros(d,p);
10
for l = 1:d
11
    gind(l,nns(l):nns(l):p) = 1;
12
    gind(l,:) = nl(l)+1-cumsum(gind(l,:),'reverse');
13
end
14
Gind = cell(1,d);
15
if nl(1)==p
16
    j=2;
17
else
18
    j=1;
19
end
20
for l=j:d
21
    Gind{l}=sparse(gind(l,:),1:p,ones(1,p),nl(l),p);
22
end
23
                      
24
% generate the data matrix
25
mu = zeros(1,p);
26
SIGMA = speye(p);
27
X = mvnrnd(mu,SIGMA,N);
28
29
% generate the response
30
T = true(p,1);
31
for l = 1:d
32
    if l==1&&nl(1)==p
33
        Tf = unifrnd(0,1,[p,1])<=ratio(l);
34
    else
35
        Tl = unifrnd(0,1,[nl(l),1])<=ratio(l);
36
        Tf = logical((Gind{l})'*Tl);
37
    end
38
    T = T&Tf;
39
end
40
41
42
beta = zeros(p,1);
43
beta(T) = normrnd(0,1,[nnz(T),1]);
44
45
y = X*beta + 0.01*normrnd(0,1,[N,1]);  
46
47
% construct ind
48
49
if nl(1) == p % if the nodes at the first layer only has one feature
50
    cumnl = [0,1,cumsum(nl(2:end))+1];
51
    ind = zeros(3,cumnl(end));
52
    ind(:,1)=[-1, -1, 1]';
53
    ind(:,end)=[1,p,0]';
54
    for i = 2:d
55
        ind(1,cumnl(i)+1:cumnl(i+1)) = 1:nns(i):p;
56
        ind(2,cumnl(i)+1:cumnl(i+1)) = ind(1,cumnl(i)+1:cumnl(i+1))+nns(i)-1;
57
        ind(3,cumnl(i)+1:cumnl(i+1)) = sqrt(nns(i));
58
    end
59
else
60
    cumnl = [0,cumsum(nl)];
61
    ind = zeros(3,cumnl(end));
62
    ind(:,end)=[1,p,0]';
63
    for i = 1:d
64
        ind(1,cumnl(i)+1:cumnl(i+1)) = 1:nns(i):p;
65
        ind(2,cumnl(i)+1:cumnl(i+1)) = ind(1,cumnl(i)+1:cumnl(i+1)) + nns(i);
66
        ind(3,cumnl(i)+1:cumnl(i+1)) = sqrt(nns(i));
67
    end
68
end
69
70
end
71