[61e40d]: / network / SparseMatrix.java

Download this file

116 lines (97 with data), 3.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/******************************************************************************
* Compilation: javac SparseMatrix.java
* Execution: java SparseMatrix
*
* A sparse, square matrix, implemented using an array of sparse
* vectors, one sparse vector per row.
*
* For matrix-matrix product, we might also want to store the
* column representation.
*
******************************************************************************/
package network;
import java.util.HashSet;
import java.util.Set;
/**
 * A square N-by-N matrix of doubles stored as an array of {@link SparseVector}s,
 * one per row.
 *
 * <p>Only the row representation is kept; element access, matrix-vector
 * product and matrix addition are all delegated to the row vectors.
 */
public class SparseMatrix {
    private final int N;                // the matrix is N-by-N
    private final SparseVector[] rows;  // rows[i] holds row i as a sparse vector

    /**
     * Initializes an N-by-N matrix of all 0s by creating one empty
     * {@code SparseVector} of size N per row.
     *
     * @param N the number of rows (and columns)
     * @throws NegativeArraySizeException if {@code N} is negative (raised by the
     *         row-array allocation)
     */
    public SparseMatrix(int N) {
        this.N = N;
        rows = new SparseVector[N];
        for (int i = 0; i < N; i++) {
            rows[i] = new SparseVector(N);
        }
    }

    /**
     * Returns whatever {@code getKeys()} reports for row {@code a}.
     * NOTE(review): presumably the column indices that hold an entry in that
     * row - confirm against SparseVector.getKeys().
     *
     * @param a the row index
     * @return the key set of row {@code a}
     * @throws ArrayIndexOutOfBoundsException if {@code a} is outside {@code [0, N)}
     */
    public Set<Integer> getKey(int a) {
        checkIndex(a);
        return rows[a].getKeys();
    }

    /**
     * Returns row {@code a} of the matrix.
     *
     * <p>The returned object is the vector stored inside this matrix, not a
     * copy, so changes made through it are visible in the matrix.
     *
     * @param a the row index
     * @return the sparse vector backing row {@code a}
     * @throws ArrayIndexOutOfBoundsException if {@code a} is outside {@code [0, N)}
     */
    public SparseVector getNeibor(int a) {
        checkIndex(a);
        return rows[a];
    }

    /**
     * Sets A[i][j] = value.
     *
     * @param i the row index
     * @param j the column index
     * @param value the value to store
     * @throws ArrayIndexOutOfBoundsException if {@code i} or {@code j} is outside {@code [0, N)}
     */
    public void put(int i, int j, double value) {
        checkIndex(i);
        checkIndex(j);
        rows[i].put(j, value);
    }

    /**
     * Returns A[i][j].
     *
     * @param i the row index
     * @param j the column index
     * @return the value row {@code i} reports for column {@code j}
     * @throws ArrayIndexOutOfBoundsException if {@code i} or {@code j} is outside {@code [0, N)}
     */
    public double get(int i, int j) {
        checkIndex(i);
        checkIndex(j);
        return rows[i].get(j);
    }

    /**
     * Returns the number of nonzero entries, computed by summing {@code nnz()}
     * over every row (linear in N, not the most efficient implementation).
     *
     * @return the total of the per-row nonzero counts
     */
    public int nnz() {
        int total = 0;
        for (SparseVector row : rows) {
            total += row.nnz();
        }
        return total;
    }

    /**
     * Returns the dimension N of this N-by-N matrix.
     *
     * @return N
     */
    public int size() {
        return N;
    }

    /**
     * Returns the matrix-vector product b = Ax, where A is this matrix.
     * Entry i of the result is the dot product of row i with {@code x}.
     *
     * @param x the vector to multiply by; its size must equal N
     * @return a new vector of size N holding the product
     * @throws IllegalArgumentException if {@code x.size()} differs from N
     */
    public SparseVector times(SparseVector x) {
        if (N != x.size()) {
            throw new IllegalArgumentException("Dimensions disagree");
        }
        SparseVector b = new SparseVector(N);
        for (int i = 0; i < N; i++) {
            b.put(i, rows[i].dot(x));
        }
        return b;
    }

    /**
     * Returns C = A + B, where A is this matrix. Neither operand is modified
     * by this method itself; each row of C is {@code A.rows[i].plus(B.rows[i])}.
     *
     * @param B the matrix to add; must have the same dimension as this one
     * @return a new matrix holding the row-wise sums
     * @throws IllegalArgumentException if the dimensions differ
     */
    public SparseMatrix plus(SparseMatrix B) {
        if (N != B.N) {
            throw new IllegalArgumentException("Dimensions disagree");
        }
        SparseMatrix C = new SparseMatrix(N);
        for (int i = 0; i < N; i++) {
            C.rows[i] = rows[i].plus(B.rows[i]);
        }
        return C;
    }

    /**
     * Returns a string representation: a header line with N and the nonzero
     * count, followed by one line per row of the form {@code "i: <row>"}.
     *
     * @return the multi-line description of this matrix
     */
    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        s.append("N = ").append(N).append(", nonzeros = ").append(nnz()).append("\n");
        for (int i = 0; i < N; i++) {
            s.append(i).append(": ").append(rows[i]).append("\n");
        }
        return s.toString();
    }

    /**
     * Validates a row or column index.
     *
     * <p>ArrayIndexOutOfBoundsException is a RuntimeException, so callers that
     * catch RuntimeException from get/put keep working, and it is also the type
     * an unchecked {@code rows[a]} access raises for a bad index.
     *
     * @param index the index to validate
     * @throws ArrayIndexOutOfBoundsException if {@code index} is outside {@code [0, N)}
     */
    private void checkIndex(int index) {
        if (index < 0 || index >= N) {
            throw new ArrayIndexOutOfBoundsException("Illegal index");
        }
    }

    /**
     * Test client: builds a 5-by-5 identity matrix with one extra entry at
     * (2, 4), then prints a vector, the matrix, their product and A + A.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        SparseMatrix A = new SparseMatrix(5);
        SparseVector x = new SparseVector(5);
        for (int i = 0; i < 5; i++) {
            A.put(i, i, 1.0);
        }
        A.put(2, 4, 0.3);
        x.put(0, 0.75);
        x.put(2, 0.11);
        System.out.println("x : " + x);
        System.out.println("A : " + A);
        System.out.println("Ax : " + A.times(x));
        System.out.println("A + A : " + A.plus(A));
    }
}