|
a |
|
b/network/SparseVector.java |
|
|
1 |
/****************************************************************************** |
|
|
2 |
* Compilation: javac SparseVector.java |
|
|
3 |
* Execution: java SparseVector |
|
|
4 |
* |
|
|
5 |
* A sparse vector, implemented using a symbol table. |
|
|
6 |
* |
|
|
7 |
* [Not clear we need the instance variable N except for error checking.] |
|
|
8 |
* |
|
|
9 |
******************************************************************************/ |
|
|
10 |
package network; |
|
|
11 |
|
|
|
12 |
import java.util.HashMap; |
|
|
13 |
import java.util.HashSet; |
|
|
14 |
import java.util.Hashtable; |
|
|
15 |
import java.util.Iterator; |
|
|
16 |
import java.util.Set; |
|
|
17 |
|
|
|
18 |
public class SparseVector implements Iterable<Integer> { |
|
|
19 |
private final int N; // length |
|
|
20 |
private ST<Integer, Double> st; // the vector, represented by index-value pairs |
|
|
21 |
|
|
|
22 |
// initialize the all 0s vector of length N |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
public SparseVector(int N) { |
|
|
26 |
this.N = N; |
|
|
27 |
this.st = new ST<Integer, Double>(); |
|
|
28 |
} |
|
|
29 |
|
|
|
30 |
public Set<Integer> getKeys() { |
|
|
31 |
return st.getKeys(); |
|
|
32 |
} |
|
|
33 |
|
|
|
34 |
@Override |
|
|
35 |
public Iterator<Integer> iterator() { |
|
|
36 |
// TODO Auto-generated method stub |
|
|
37 |
return st.getKeys().iterator(); |
|
|
38 |
} |
|
|
39 |
|
|
|
40 |
public Iterable<Integer> getKeysIter() { |
|
|
41 |
return st.keys(); |
|
|
42 |
} |
|
|
43 |
|
|
|
44 |
// put st[i] = value |
|
|
45 |
public void put(int i, double value) { |
|
|
46 |
if (i < 0 || i >= N) throw new RuntimeException("Illegal index"); |
|
|
47 |
if (value == 0.0) st.delete(i); |
|
|
48 |
else st.put(i, value); |
|
|
49 |
} |
|
|
50 |
|
|
|
51 |
// return st[i] |
|
|
52 |
public double get(int i) { |
|
|
53 |
if (i < 0 || i >= N) throw new RuntimeException("Illegal index"); |
|
|
54 |
if (st.contains(i)) return st.get(i); |
|
|
55 |
else return 0.0; |
|
|
56 |
} |
|
|
57 |
|
|
|
58 |
// return the number of nonzero entries |
|
|
59 |
public int nnz() { |
|
|
60 |
return st.size(); |
|
|
61 |
} |
|
|
62 |
|
|
|
63 |
// return the size of the vector |
|
|
64 |
public int size() { |
|
|
65 |
return N; |
|
|
66 |
} |
|
|
67 |
|
|
|
68 |
// return the dot product of this vector a with b |
|
|
69 |
public double dot(SparseVector b) { |
|
|
70 |
SparseVector a = this; |
|
|
71 |
if (a.N != b.N) throw new RuntimeException("Vector lengths disagree"); |
|
|
72 |
double sum = 0.0; |
|
|
73 |
|
|
|
74 |
// iterate over the vector with the fewest nonzeros |
|
|
75 |
if (a.st.size() <= b.st.size()) { |
|
|
76 |
for (int i : a.st.keys()) |
|
|
77 |
if (b.st.contains(i)) sum += a.get(i) * b.get(i); |
|
|
78 |
} |
|
|
79 |
else { |
|
|
80 |
for (int i : b.st.keys()) |
|
|
81 |
if (a.st.contains(i)) sum += a.get(i) * b.get(i); |
|
|
82 |
} |
|
|
83 |
return sum; |
|
|
84 |
} |
|
|
85 |
|
|
|
86 |
// return the 2-norm |
|
|
87 |
public double norm() { |
|
|
88 |
SparseVector a = this; |
|
|
89 |
return Math.sqrt(a.dot(a)); |
|
|
90 |
} |
|
|
91 |
|
|
|
92 |
// return alpha * a |
|
|
93 |
public SparseVector scale(double alpha) { |
|
|
94 |
SparseVector a = this; |
|
|
95 |
SparseVector c = new SparseVector(N); |
|
|
96 |
for (int i : a.st.keys()) c.put(i, alpha * a.get(i)); |
|
|
97 |
return c; |
|
|
98 |
} |
|
|
99 |
|
|
|
100 |
// return a + b |
|
|
101 |
public SparseVector plus(SparseVector b) { |
|
|
102 |
SparseVector a = this; |
|
|
103 |
if (a.N != b.N) throw new RuntimeException("Vector lengths disagree"); |
|
|
104 |
SparseVector c = new SparseVector(N); |
|
|
105 |
for (int i : a.st.keys()) c.put(i, a.get(i)); // c = a |
|
|
106 |
for (int i : b.st.keys()) c.put(i, b.get(i) + c.get(i)); // c = c + b |
|
|
107 |
return c; |
|
|
108 |
} |
|
|
109 |
|
|
|
110 |
// return a string representation |
|
|
111 |
public String toString() { |
|
|
112 |
StringBuilder s = new StringBuilder(); |
|
|
113 |
for (int i : st.keys()) { |
|
|
114 |
s.append("(" + i + ", " + st.get(i) + ") "); |
|
|
115 |
} |
|
|
116 |
return s.toString(); |
|
|
117 |
} |
|
|
118 |
|
|
|
119 |
|
|
|
120 |
// test client |
|
|
121 |
public static void main(String[] args) { |
|
|
122 |
SparseVector a = new SparseVector(10); |
|
|
123 |
SparseVector b = new SparseVector(10); |
|
|
124 |
a.put(3, 0.50); |
|
|
125 |
a.put(9, 0.75); |
|
|
126 |
a.put(6, 0.11); |
|
|
127 |
a.put(6, 0.00); |
|
|
128 |
b.put(3, 0.60); |
|
|
129 |
b.put(4, 0.90); |
|
|
130 |
System.out.println("a = " + a); |
|
|
131 |
System.out.println("b = " + b); |
|
|
132 |
System.out.println("a dot b = " + a.dot(b)); |
|
|
133 |
System.out.println("a + b = " + a.plus(b)); |
|
|
134 |
} |
|
|
135 |
|
|
|
136 |
|
|
|
137 |
|
|
|
138 |
} |