--- a +++ b/network/SparseVector.java @@ -0,0 +1,138 @@ +/****************************************************************************** + * Compilation: javac SparseVector.java + * Execution: java SparseVector + * + * A sparse vector, implementing using a symbol table. + * + * [Not clear we need the instance variable N except for error checking.] + * + ******************************************************************************/ +package network; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Hashtable; +import java.util.Iterator; +import java.util.Set; + +public class SparseVector implements Iterable<Integer> { + private final int N; // length + private ST<Integer, Double> st; // the vector, represented by index-value pairs + + // initialize the all 0s vector of length N + + + public SparseVector(int N) { + this.N = N; + this.st = new ST<Integer, Double>(); + } + + public Set<Integer> getKeys() { + return st.getKeys(); + } + + @Override + public Iterator<Integer> iterator() { + // TODO Auto-generated method stub + return st.getKeys().iterator(); + } + + public Iterable<Integer> getKeysIter() { + return st.keys(); + } + + // put st[i] = value + public void put(int i, double value) { + if (i < 0 || i >= N) throw new RuntimeException("Illegal index"); + if (value == 0.0) st.delete(i); + else st.put(i, value); + } + + // return st[i] + public double get(int i) { + if (i < 0 || i >= N) throw new RuntimeException("Illegal index"); + if (st.contains(i)) return st.get(i); + else return 0.0; + } + + // return the number of nonzero entries + public int nnz() { + return st.size(); + } + + // return the size of the vector + public int size() { + return N; + } + + // return the dot product of this vector a with b + public double dot(SparseVector b) { + SparseVector a = this; + if (a.N != b.N) throw new RuntimeException("Vector lengths disagree"); + double sum = 0.0; + + // iterate over the vector with the fewest nonzeros + if (a.st.size() <= b.st.size()) { + for (int i : a.st.keys()) + if (b.st.contains(i)) sum += a.get(i) * b.get(i); + } + else { + for (int i : b.st.keys()) + if (a.st.contains(i)) sum += a.get(i) * b.get(i); + } + return sum; + } + + // return the 2-norm + public double norm() { + SparseVector a = this; + return Math.sqrt(a.dot(a)); + } + + // return alpha * a + public SparseVector scale(double alpha) { + SparseVector a = this; + SparseVector c = new SparseVector(N); + for (int i : a.st.keys()) c.put(i, alpha * a.get(i)); + return c; + } + + // return a + b + public SparseVector plus(SparseVector b) { + SparseVector a = this; + if (a.N != b.N) throw new RuntimeException("Vector lengths disagree"); + SparseVector c = new SparseVector(N); + for (int i : a.st.keys()) c.put(i, a.get(i)); // c = a + for (int i : b.st.keys()) c.put(i, b.get(i) + c.get(i)); // c = c + b + return c; + } + + // return a string representation + public String toString() { + StringBuilder s = new StringBuilder(); + for (int i : st.keys()) { + s.append("(" + i + ", " + st.get(i) + ") "); + } + return s.toString(); + } + + + // test client + public static void main(String[] args) { + SparseVector a = new SparseVector(10); + SparseVector b = new SparseVector(10); + a.put(3, 0.50); + a.put(9, 0.75); + a.put(6, 0.11); + a.put(6, 0.00); + b.put(3, 0.60); + b.put(4, 0.90); + System.out.println("a = " + a); + System.out.println("b = " + b); + System.out.println("a dot b = " + a.dot(b)); + System.out.println("a + b = " + a.plus(b)); + } + + + +}