[61e40d]: / network / SparseVector.java

Download this file

139 lines (115 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/******************************************************************************
* Compilation: javac SparseVector.java
* Execution: java SparseVector
*
* A sparse vector, implemented using a symbol table.
*
* [Not clear we need the instance variable N except for error checking.]
*
******************************************************************************/
package network;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Set;
/**
 * A fixed-length vector of doubles that stores only its nonzero entries, as
 * index-value pairs in a symbol table ({@code ST}).
 *
 * <p>An index that has never been set, or that was last set to {@code 0.0},
 * has no stored entry and reads back as {@code 0.0}. Iterating over a vector
 * yields the indices that currently have a stored entry.
 *
 * <p>NOTE(review): the order in which indices are produced is whatever the
 * project's {@code ST} yields from {@code keys()} / {@code getKeys()} --
 * confirm against that class before relying on it.
 */
public class SparseVector implements Iterable<Integer> {
    private final int N;                   // length
    private final ST<Integer, Double> st;  // the vector, represented by index-value pairs

    /**
     * Creates the all-zeros vector of length {@code N}.
     *
     * @param N the length of the vector; valid indices are {@code 0 .. N-1}
     */
    public SparseVector(int N) {
        this.N = N;
        this.st = new ST<Integer, Double>();
    }

    /**
     * Returns the indices that currently have a stored entry, as reported by
     * the underlying symbol table.
     *
     * <p>NOTE(review): whether the returned set is a live view or a copy
     * depends on {@code ST.getKeys()} -- confirm before mutating it.
     *
     * @return the set of stored indices
     */
    public Set<Integer> getKeys() {
        return st.getKeys();
    }

    /**
     * Returns an iterator over the indices that currently have a stored entry.
     *
     * @return an iterator obtained from {@code ST.getKeys()}
     */
    @Override
    public Iterator<Integer> iterator() {
        return st.getKeys().iterator();
    }

    /**
     * Returns the stored indices as an {@code Iterable}, obtained from
     * {@code ST.keys()}.
     *
     * @return the stored indices
     */
    public Iterable<Integer> getKeysIter() {
        return st.keys();
    }

    /**
     * Sets entry {@code i} to {@code value}.
     *
     * <p>A value equal to {@code 0.0} (which includes {@code -0.0}) removes the
     * entry instead of storing it, so the table only ever holds nonzeros.
     *
     * @param i     the index to set
     * @param value the new value
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, N)}
     */
    public void put(int i, double value) {
        checkIndex(i);
        if (value == 0.0) {
            st.delete(i);
        } else {
            st.put(i, value);
        }
    }

    /**
     * Returns entry {@code i}.
     *
     * @param i the index to read
     * @return the stored value, or {@code 0.0} if no entry is stored at {@code i}
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, N)}
     */
    public double get(int i) {
        checkIndex(i);
        // Test membership first rather than unboxing whatever ST.get returns
        // for an absent key.
        if (st.contains(i)) {
            return st.get(i);
        }
        return 0.0;
    }

    /**
     * Returns the number of stored (nonzero) entries.
     *
     * @return the number of entries in the symbol table
     */
    public int nnz() {
        return st.size();
    }

    /**
     * Returns the length of the vector.
     *
     * @return the length given at construction
     */
    public int size() {
        return N;
    }

    /**
     * Returns the dot product of this vector with {@code b}.
     *
     * @param b the other vector
     * @return the sum of the products of entries stored in both vectors
     * @throws IllegalArgumentException if the two vectors differ in length
     */
    public double dot(SparseVector b) {
        checkSameLength(b);

        // Iterate over the vector with the fewest nonzeros; ties go to this one.
        SparseVector fewer = this;
        SparseVector more = b;
        if (this.st.size() > b.st.size()) {
            fewer = b;
            more = this;
        }

        double sum = 0.0;
        for (int i : fewer.st.keys()) {
            if (more.st.contains(i)) {
                sum += fewer.get(i) * more.get(i);
            }
        }
        return sum;
    }

    /**
     * Returns the 2-norm of this vector.
     *
     * @return the square root of this vector's dot product with itself
     */
    public double norm() {
        return Math.sqrt(this.dot(this));
    }

    /**
     * Returns a new vector equal to {@code alpha} times this vector.
     *
     * @param alpha the scale factor
     * @return a new vector of the same length; this vector is not modified
     */
    public SparseVector scale(double alpha) {
        SparseVector c = new SparseVector(N);
        for (int i : st.keys()) {
            c.put(i, alpha * get(i));
        }
        return c;
    }

    /**
     * Returns a new vector equal to this vector plus {@code b}.
     *
     * @param b the vector to add
     * @return a new vector of the same length; neither operand is modified
     * @throws IllegalArgumentException if the two vectors differ in length
     */
    public SparseVector plus(SparseVector b) {
        checkSameLength(b);
        SparseVector c = new SparseVector(N);
        for (int i : st.keys()) {
            c.put(i, get(i));                 // c = a
        }
        for (int i : b.st.keys()) {
            c.put(i, b.get(i) + c.get(i));    // c = c + b
        }
        return c;
    }

    /**
     * Returns the stored entries formatted as {@code "(index, value) "} pairs,
     * concatenated in the order produced by {@code ST.keys()}.
     *
     * @return the string form; empty when no entry is stored
     */
    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        for (int i : st.keys()) {
            s.append("(").append(i).append(", ").append(st.get(i)).append(") ");
        }
        return s.toString();
    }

    // Rejects an index outside [0, N).
    private void checkIndex(int i) {
        if (i < 0 || i >= N) {
            throw new IndexOutOfBoundsException("Illegal index: " + i + " (vector length " + N + ")");
        }
    }

    // Rejects an operand whose length differs from this vector's.
    private void checkSameLength(SparseVector b) {
        if (N != b.N) {
            throw new IllegalArgumentException("Vector lengths disagree: " + N + " vs " + b.N);
        }
    }

    /**
     * Test client: builds two small vectors and prints them, their dot
     * product, and their sum.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        SparseVector a = new SparseVector(10);
        SparseVector b = new SparseVector(10);
        a.put(3, 0.50);
        a.put(9, 0.75);
        a.put(6, 0.11);
        a.put(6, 0.00);   // setting to zero removes the entry stored just above
        b.put(3, 0.60);
        b.put(4, 0.90);
        System.out.println("a = " + a);
        System.out.println("b = " + b);
        System.out.println("a dot b = " + a.dot(b));
        System.out.println("a + b = " + a.plus(b));
    }
}