Diff of /network/SparseMatrix.java [000000] .. [4fba4e]

Switch to side-by-side view

--- a
+++ b/network/SparseMatrix.java
@@ -0,0 +1,115 @@
+/******************************************************************************
+ *  Compilation:  javac SparseMatrix.java
+ *  Execution:    java SparseMatrix
+ *  
+ *  A sparse, square matrix, implementing using two arrays of sparse
+ *  vectors, one representation for the rows and one for the columns.
+ *
+ *  For matrix-matrix product, we might also want to store the
+ *  column representation.
+ *
+ ******************************************************************************/
+
+
+package network;
+
+import java.util.HashSet;
+import java.util.Set;
+
+public class SparseMatrix {
+    private final int N;           // N-by-N matrix
+    private SparseVector[] rows;   // the rows, each row is a sparse vector
+
+    // initialize an N-by-N matrix of all 0s
+    public SparseMatrix(int N) {
+        this.N  = N;
+        rows = new SparseVector[N];
+        for (int i = 0; i < N; i++)
+            rows[i] = new SparseVector(N);
+    }
+
+    public Set<Integer> getKey(int a) {
+    	return rows[a].getKeys();
+    }
+    
+    public SparseVector getNeibor(int a) {
+    	return rows[a];
+    }
+    
+    // put A[i][j] = value
+    public void put(int i, int j, double value) {
+        if (i < 0 || i >= N) throw new RuntimeException("Illegal index");
+        if (j < 0 || j >= N) throw new RuntimeException("Illegal index");
+        rows[i].put(j, value);
+    }
+
+    // return A[i][j]
+    public double get(int i, int j) {
+        if (i < 0 || i >= N) throw new RuntimeException("Illegal index");
+        if (j < 0 || j >= N) throw new RuntimeException("Illegal index");
+        return rows[i].get(j);
+    }
+
+    // return the number of nonzero entries (not the most efficient implementation)
+    public int nnz() { 
+        int sum = 0;
+        for (int i = 0; i < N; i++)
+            sum += rows[i].nnz();
+        return sum;
+    }
+    
+    public int size() { 
+        return N;
+    }
+
+    // return the matrix-vector product b = Ax
+    public SparseVector times(SparseVector x) {
+        SparseMatrix A = this;
+        if (N != x.size()) throw new RuntimeException("Dimensions disagree");
+        SparseVector b = new SparseVector(N);
+        for (int i = 0; i < N; i++)
+            b.put(i, A.rows[i].dot(x));
+        return b;
+    }
+
+    // return C = A + B
+    public SparseMatrix plus(SparseMatrix B) {
+        SparseMatrix A = this;
+        if (A.N != B.N) throw new RuntimeException("Dimensions disagree");
+        SparseMatrix C = new SparseMatrix(N);
+        for (int i = 0; i < N; i++)
+            C.rows[i] = A.rows[i].plus(B.rows[i]);
+        return C;
+    }
+
+
+    // return a string representation
+    public String toString() {
+        StringBuilder s = new StringBuilder();
+        s.append("N = " + N + ", nonzeros = " + nnz() + "\n");
+        for (int i = 0; i < N; i++) {
+            s.append(i + ": " + rows[i] + "\n");
+        }
+        return s.toString();
+    }
+
+
+    // test client
+    public static void main(String[] args) {
+        SparseMatrix A = new SparseMatrix(5);
+        SparseVector x = new SparseVector(5);
+        A.put(0, 0, 1.0);
+        A.put(1, 1, 1.0);
+        A.put(2, 2, 1.0);
+        A.put(3, 3, 1.0);
+        A.put(4, 4, 1.0);
+        A.put(2, 4, 0.3);
+        x.put(0, 0.75);
+        x.put(2, 0.11);
+        System.out.println("x     : " + x);
+        System.out.println("A     : " + A);
+        System.out.println("Ax    : " + A.times(x));
+        System.out.println("A + A : " + A.plus(A));
+    }
+
+}