Diff of /MLFre/DPC/SLEP/altra.h [000000] .. [d8e26d]

Switch to unified view

a b/MLFre/DPC/SLEP/altra.h
1
#include "mex.h"
2
#include <stdio.h>
3
#include <math.h>
4
#include <string.h>
5
6
7
/*
8
 * Important Notice: September 20, 2010
9
 *
10
 * In this head file, we assume that the features in the tree strucutre
11
 * are well ordered. That is to say, the indices of the left nodes is always less
12
 * than the right nodes. Ideally, this can be achieved by reordering the features.
13
 *
14
 * The advantage of this ordered features is that, we donot need to use an explicit
15
 * variable for recording the indices.
16
 *
17
 * To deal with the more general case when the features might not be well ordered,
18
 * we provide the functions in the head file "general_altra.h". Compared with the files in this head file,
19
 * we need an additional parameter G, which contains the indices of the nodes.
20
 *
21
 *
22
 */
23
24
/*
25
 * -------------------------------------------------------------------
26
 *                       Functions and parameter
27
 * -------------------------------------------------------------------
28
 *
29
 * altra solves the following problem
30
 *
31
 * 1/2 \|x-v\|^2 + \sum \lambda_i \|x_{G_i}\|,
32
 *
33
 * where x and v are of dimension n,
34
 *       \lambda_i >=0, and G_i's follow the tree structure
35
 *
36
 * It is implemented in Matlab as follows:
37
 *
38
 * x=altra(v, n, ind, nodes);
39
 *
40
 * ind is a 3 x nodes matrix.
41
 *       Each column corresponds to a node.
42
 *
43
 *       The first element of each column is the starting index,
44
 *       the second element of each column is the ending index
45
 *       the third element of each column corrreponds to \lambbda_i.
46
 *
47
 * -------------------------------------------------------------------
48
 *                       Notices:
49
 * -------------------------------------------------------------------
50
 *
51
 * 1. The nodes in the parameter "ind" should be given in the 
52
 *    either
53
 *           the postordering of depth-first traversal
54
 *    or 
55
 *           the reverse breadth-first traversal.
56
 *
57
 * 2. When each elements of x are penalized via the same L1 
58
 *    (equivalent to the L2 norm) parameter, one can simplify the input
59
 *    by specifying 
60
 *           the "first" column of ind as (-1, -1, lambda)
61
 *
62
 *    In this case, we treat it as a single "super" node. Thus in the value
63
 *    nodes, we only count it once.
64
 *
65
 * 3. The values in "ind" are in [1,n].
66
 *
67
 * 4. The third element of each column should be positive. The program does
68
 *    not check the validity of the parameter. 
69
 *
70
 *    It is still valid to use the zero regularization parameter.
71
 *    In this case, the program does not change the values of 
72
 *    correponding indices.
73
 *    
74
 *
75
 * -------------------------------------------------------------------
76
 *                       History:
77
 * -------------------------------------------------------------------
78
 *
79
 * Composed by Jun Liu on April 20, 2010
80
 *
81
 * For any question or suggestion, please email j.liu@asu.edu.
82
 *
83
 */
84
85
86
void altra(double *x, double *v, int n, double *ind, int nodes){
87
    
88
    int i, j, m;
89
    double lambda,twoNorm, ratio;
90
    
91
    /*
92
     * test whether the first node is special
93
     */
94
    if ((int) ind[0]==-1){
95
        
96
        /*
97
         *Recheck whether ind[1] equals to zero
98
         */
99
        if ((int) ind[1]!=-1){
100
            printf("\n Error! \n Check ind");
101
            exit(1);
102
        }        
103
        
104
        lambda=ind[2];
105
        
106
        for(j=0;j<n;j++){
107
            if (v[j]>lambda)
108
                x[j]=v[j]-lambda;
109
            else
110
                if (v[j]<-lambda)
111
                    x[j]=v[j]+lambda;
112
                else
113
                    x[j]=0;
114
        }
115
        
116
        i=1;
117
    }
118
    else{
119
        memcpy(x, v, sizeof(double) * n);
120
        i=0;
121
    }
122
            
123
    /*
124
     * sequentially process each node
125
     *
126
     */
127
    for(;i < nodes; i++){
128
        /*
129
         * compute the L2 norm of this group         
130
         */
131
        twoNorm=0;
132
        for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++)
133
            twoNorm += x[j] * x[j];        
134
        twoNorm=sqrt(twoNorm);
135
        
136
        lambda=ind[3*i+2];
137
        if (twoNorm>lambda){
138
            ratio=(twoNorm-lambda)/twoNorm;
139
            
140
            /*
141
             * shrinkage this group by ratio
142
             */
143
            for(j=(int) ind[3*i]-1;j<(int) ind[3*i+1];j++)
144
                x[j]*=ratio;            
145
        }
146
        else{
147
            /*
148
             * threshold this group to zero
149
             */
150
            for(j=(int) ind[3*i]-1;j<(int) ind[3*i+1];j++)
151
                x[j]=0;
152
        }
153
    }
154
}
155
156
157
158
/*
159
 * altra_mt is a generalization of altra to the 
160
 * 
161
 * multi-task learning scenario (or equivalently the multi-class case)
162
 *
163
 * altra_mt(X, V, n, k, ind, nodes);
164
 *
165
 * It applies altra for each row (1xk) of X and V
166
 *
167
 */
168
169
170
void altra_mt(double *X, double *V, int n, int k, double *ind, int nodes){
171
    int i, j;
172
    
173
    double *x=(double *)malloc(sizeof(double)*k);
174
    double *v=(double *)malloc(sizeof(double)*k);
175
    
176
    for (i=0;i<n;i++){
177
        /*
178
         * copy a row of V to v
179
         *         
180
         */
181
        for(j=0;j<k;j++)
182
            v[j]=V[j*n + i];
183
        
184
        altra(x, v, k, ind, nodes);
185
        
186
        /*
187
         * copy the solution to X         
188
         */        
189
        for(j=0;j<k;j++)
190
            X[j*n+i]=x[j];
191
    }
192
    
193
    free(x);
194
    free(v);
195
}
196
197
198
199
200
/*
201
 * compute
202
 *  lambda2_max=computeLambda2Max(x,n,ind,nodes);
203
 *
204
 * compute the 2 norm of each group, which is divided by the ind(3,:),
205
 * then the maximum value is returned
206
 */
207
    /*
208
     *This function does not consider the case ind={[-1, -1, 100]',...}
209
     *
210
     *This functions is not used currently.
211
     */
212
213
void computeLambda2Max(double *lambda2_max, double *x, int n, double *ind, int nodes){
214
    int i, j, m;
215
    double lambda,twoNorm;
216
    
217
    *lambda2_max=0;
218
    
219
    for(i=0;i < nodes; i++){
220
        /*
221
         * compute the L2 norm of this group         
222
         */
223
        twoNorm=0;
224
        for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++)
225
            twoNorm += x[j] * x[j];        
226
        twoNorm=sqrt(twoNorm);
227
        
228
        twoNorm=twoNorm/ind[3*i+2];
229
        
230
        if (twoNorm >*lambda2_max )
231
            *lambda2_max=twoNorm;        
232
    }
233
}
234
235
/*
236
 * -------------------------------------------------------------------
237
 *                       Function and parameter
238
 * -------------------------------------------------------------------
239
 *
240
 * treeNorm compute
241
 *
242
 *        \sum \lambda_i \|x_{G_i}\|,
243
 *
244
 * where x is of dimension n,
245
 *       \lambda_i >=0, and G_i's follow the tree structure
246
 *
247
 * The file is implemented in the following in Matlab:
248
 *
249
 * tree_norm=treeNorm(x, n, ind,nodes);
250
 */
251
252
253
void treeNorm(double *tree_norm, double *x, int n, double *ind, int nodes){
254
    
255
    int i, j, m;
256
    double twoNorm, lambda;
257
    
258
    *tree_norm=0;
259
    
260
    /*
261
     * test whether the first node is special
262
     */
263
    if ((int) ind[0]==-1){
264
        
265
        /*
266
         *Recheck whether ind[1] equals to zero
267
         */
268
        if ((int) ind[1]!=-1){
269
            printf("\n Error! \n Check ind");
270
            exit(1);
271
        }        
272
        
273
        lambda=ind[2];
274
        
275
        for(j=0;j<n;j++){
276
            *tree_norm+=fabs(x[j]);
277
        }
278
        
279
        *tree_norm=*tree_norm * lambda;
280
        
281
        i=1;
282
    }
283
    else{
284
        i=0;
285
    }
286
            
287
    /*
288
     * sequentially process each node
289
     *
290
     */
291
    for(;i < nodes; i++){
292
        /*
293
         * compute the L2 norm of this group         
294
         */
295
        twoNorm=0;
296
        for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++)
297
            twoNorm += x[j] * x[j];        
298
        twoNorm=sqrt(twoNorm);
299
        
300
        lambda=ind[3*i+2];
301
        
302
        *tree_norm=*tree_norm + lambda*twoNorm;
303
    }
304
}
305
306
307
/*
308
 * -------------------------------------------------------------------
309
 *                       Function and parameter
310
 * -------------------------------------------------------------------
311
 *
312
 * findLambdaMax compute
313
 * 
314
 * the lambda_{max} that achieves a zero solution for
315
 *
316
 *     min  1/2 \|x-v\|^2 +  \lambda_{\max} * \sum  w_i \|x_{G_i}\|,
317
 *
318
 * where x is of dimension n,
319
 *       w_i >=0, and G_i's follow the tree structure
320
 *
321
 * The file is implemented in the following in Matlab:
322
 *
323
 * lambdaMax=findLambdaMax(v, n, ind,nodes);
324
 */
325
326
void findLambdaMax(double *lambdaMax, double *v, int n, double *ind, int nodes){
327
 
328
    int i, j;
329
    double lambda=0,squaredWeight=0, lambda1,lambda2;
330
    double *x=(double *)malloc(sizeof(double)*n);
331
    double *ind2=(double *)malloc(sizeof(double)*nodes*3);
332
    int num=0;
333
       
334
    for(i=0;i<n;i++){
335
        lambda+=v[i]*v[i];
336
    }
337
    
338
    if ( (int)ind[0]==-1 )
339
        squaredWeight=n*ind[2]*ind[2];
340
    else
341
        squaredWeight=ind[2]*ind[2];
342
    
343
    for (i=1;i<nodes;i++){
344
        squaredWeight+=ind[3*i+2]*ind[3*i+2];
345
    }
346
    
347
    /* set lambda to an initial guess
348
     */
349
    lambda=sqrt(lambda/squaredWeight);
350
    
351
    /*
352
    printf("\n\n   lambda=%2.5f",lambda);
353
    */
354
    
355
    /*
356
     *copy ind to ind2,
357
     *and scale the weight 3*i+2
358
     */
359
    for(i=0;i<nodes;i++){
360
        ind2[3*i]=ind[3*i];
361
        ind2[3*i+1]=ind[3*i+1];
362
        ind2[3*i+2]=ind[3*i+2]*lambda;
363
    }
364
    
365
    /* test whether the solution is zero or not
366
     */
367
    altra(x, v, n, ind2, nodes);    
368
    for(i=0;i<n;i++){
369
        if (x[i]!=0)
370
            break;
371
    }
372
    
373
    if (i>=n) {
374
        /*x is a zero vector*/
375
        lambda2=lambda;
376
        lambda1=lambda;
377
        
378
        num=0;
379
        
380
        while(1){
381
            num++;
382
            
383
            lambda2=lambda;
384
            lambda1=lambda1/2;
385
            /* update ind2
386
             */
387
            for(i=0;i<nodes;i++){
388
                ind2[3*i+2]=ind[3*i+2]*lambda1;
389
            }
390
            
391
            /* compute and test whether x is zero
392
             */
393
            altra(x, v, n, ind2, nodes);
394
            for(i=0;i<n;i++){
395
                if (x[i]!=0)
396
                    break;
397
            }
398
            
399
            if (i<n){
400
                break;
401
                /*x is not zero
402
                 *we have found lambda1
403
                 */
404
            }
405
        }
406
407
    }
408
    else{
409
        /*x is a non-zero vector*/
410
        lambda2=lambda;
411
        lambda1=lambda;
412
        
413
        num=0;
414
        while(1){
415
            num++;            
416
            
417
            lambda1=lambda2;
418
            lambda2=lambda2*2;
419
            /* update ind2
420
             */
421
            for(i=0;i<nodes;i++){
422
                ind2[3*i+2]=ind[3*i+2]*lambda2;
423
            }
424
            
425
            /* compute and test whether x is zero
426
             */
427
            altra(x, v, n, ind2, nodes);
428
            for(i=0;i<n;i++){
429
                if (x[i]!=0)
430
                    break;
431
            }
432
            
433
            if (i>=n){
434
                break;
435
                /*x is a zero vector
436
                 *we have found lambda2
437
                 */
438
            }
439
        }
440
    }    
441
    
442
    /*
443
    printf("\n num=%d, lambda1=%2.5f, lambda2=%2.5f",num, lambda1,lambda2);
444
    */
445
    
446
    while ( fabs(lambda2-lambda1) > lambda2 * 1e-10 ){
447
        
448
        num++;
449
        
450
        lambda=(lambda1+lambda2)/2;
451
        
452
        /* update ind2
453
         */
454
        for(i=0;i<nodes;i++){
455
            ind2[3*i+2]=ind[3*i+2]*lambda;
456
        }
457
        
458
        /* compute and test whether x is zero
459
         */
460
        altra(x, v, n, ind2, nodes);
461
        for(i=0;i<n;i++){
462
            if (x[i]!=0)
463
                break;
464
        }
465
        
466
        if (i>=n){
467
            lambda2=lambda;
468
        }
469
        else{
470
            lambda1=lambda;
471
        }
472
        
473
       /*
474
        printf("\n lambda1=%2.5f, lambda2=%2.5f",lambda1,lambda2);
475
        */
476
    }
477
    
478
   /*
479
    printf("\n num=%d",num);
480
    
481
    printf("   lambda1=%2.5f, lambda2=%2.5f",lambda1,lambda2);
482
     
483
    */
484
    
485
    *lambdaMax=lambda2;
486
    
487
    free(x);
488
    free(ind2);
489
}
490
491
492
/*
493
 * findLambdaMax_mt is a generalization of findLambdaMax to the 
494
 * 
495
 * multi-task learning scenario (or equivalently the multi-class case)
496
 *
497
 * lambdaMax=findLambdaMax_mt(X, V, n, k, ind, nodes);
498
 *
499
 * It applies findLambdaMax for each row (1xk) of X and V
500
 *
501
 */
502
503
504
void findLambdaMax_mt(double *lambdaMax, double *V, int n, int k, double *ind, int nodes){
505
    int i, j;
506
    
507
    double *v=(double *)malloc(sizeof(double)*k);
508
    double lambda;
509
    
510
    *lambdaMax=0;
511
    
512
    for (i=0;i<n;i++){
513
        /*
514
         * copy a row of V to v
515
         *         
516
         */
517
        for(j=0;j<k;j++)
518
            v[j]=V[j*n + i];
519
        
520
        findLambdaMax(&lambda, v, k, ind, nodes);
521
        
522
        /*
523
        printf("\n   lambda=%5.2f",lambda);        
524
         */
525
        
526
        if (lambda>*lambdaMax)
527
            *lambdaMax=lambda;
528
    }
529
    
530
    /*
531
    printf("\n *lambdaMax=%5.2f",*lambdaMax);
532
     */
533
    
534
    free(v);
535
}
536