a b/network/SparseMatrix.java
1
/******************************************************************************
2
 *  Compilation:  javac SparseMatrix.java
3
 *  Execution:    java SparseMatrix
4
 *  
5
 *  A sparse, square matrix, implemented as an array of sparse vectors,
 *  one per row (only the row representation is currently stored).
7
 *
8
 *  For matrix-matrix product, we might also want to store the
9
 *  column representation.
10
 *
11
 ******************************************************************************/
12
13
14
package network;
15
16
import java.util.HashSet;
17
import java.util.Set;
18
19
public class SparseMatrix {
20
    private final int N;           // N-by-N matrix
21
    private SparseVector[] rows;   // the rows, each row is a sparse vector
22
23
    // initialize an N-by-N matrix of all 0s
24
    public SparseMatrix(int N) {
25
        this.N  = N;
26
        rows = new SparseVector[N];
27
        for (int i = 0; i < N; i++)
28
            rows[i] = new SparseVector(N);
29
    }
30
31
    public Set<Integer> getKey(int a) {
32
        return rows[a].getKeys();
33
    }
34
    
35
    public SparseVector getNeibor(int a) {
36
        return rows[a];
37
    }
38
    
39
    // put A[i][j] = value
40
    public void put(int i, int j, double value) {
41
        if (i < 0 || i >= N) throw new RuntimeException("Illegal index");
42
        if (j < 0 || j >= N) throw new RuntimeException("Illegal index");
43
        rows[i].put(j, value);
44
    }
45
46
    // return A[i][j]
47
    public double get(int i, int j) {
48
        if (i < 0 || i >= N) throw new RuntimeException("Illegal index");
49
        if (j < 0 || j >= N) throw new RuntimeException("Illegal index");
50
        return rows[i].get(j);
51
    }
52
53
    // return the number of nonzero entries (not the most efficient implementation)
54
    public int nnz() { 
55
        int sum = 0;
56
        for (int i = 0; i < N; i++)
57
            sum += rows[i].nnz();
58
        return sum;
59
    }
60
    
61
    public int size() { 
62
        return N;
63
    }
64
65
    // return the matrix-vector product b = Ax
66
    public SparseVector times(SparseVector x) {
67
        SparseMatrix A = this;
68
        if (N != x.size()) throw new RuntimeException("Dimensions disagree");
69
        SparseVector b = new SparseVector(N);
70
        for (int i = 0; i < N; i++)
71
            b.put(i, A.rows[i].dot(x));
72
        return b;
73
    }
74
75
    // return C = A + B
76
    public SparseMatrix plus(SparseMatrix B) {
77
        SparseMatrix A = this;
78
        if (A.N != B.N) throw new RuntimeException("Dimensions disagree");
79
        SparseMatrix C = new SparseMatrix(N);
80
        for (int i = 0; i < N; i++)
81
            C.rows[i] = A.rows[i].plus(B.rows[i]);
82
        return C;
83
    }
84
85
86
    // return a string representation
87
    public String toString() {
88
        StringBuilder s = new StringBuilder();
89
        s.append("N = " + N + ", nonzeros = " + nnz() + "\n");
90
        for (int i = 0; i < N; i++) {
91
            s.append(i + ": " + rows[i] + "\n");
92
        }
93
        return s.toString();
94
    }
95
96
97
    // test client
98
    public static void main(String[] args) {
99
        SparseMatrix A = new SparseMatrix(5);
100
        SparseVector x = new SparseVector(5);
101
        A.put(0, 0, 1.0);
102
        A.put(1, 1, 1.0);
103
        A.put(2, 2, 1.0);
104
        A.put(3, 3, 1.0);
105
        A.put(4, 4, 1.0);
106
        A.put(2, 4, 0.3);
107
        x.put(0, 0.75);
108
        x.put(2, 0.11);
109
        System.out.println("x     : " + x);
110
        System.out.println("A     : " + A);
111
        System.out.println("Ax    : " + A.times(x));
112
        System.out.println("A + A : " + A.plus(A));
113
    }
114
115
}