Diff of /network/SparseVector.java [000000] .. [4fba4e]

Switch to side-by-side view

--- a
+++ b/network/SparseVector.java
@@ -0,0 +1,138 @@
+/******************************************************************************
+ *  Compilation:  javac SparseVector.java
+ *  Execution:    java SparseVector
+ *  
+ *  A sparse vector, implementing using a symbol table.
+ *
+ *  [Not clear we need the instance variable N except for error checking.]
+ *
+ ******************************************************************************/
+package network;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Hashtable;
+import java.util.Iterator;
+import java.util.Set;
+
+public class SparseVector implements Iterable<Integer> {
+    private final int N;             // length
+    private ST<Integer, Double> st;  // the vector, represented by index-value pairs
+
+    // initialize the all 0s vector of length N
+    
+    
+    public SparseVector(int N) {
+        this.N  = N;
+        this.st = new ST<Integer, Double>();
+    }
+
+    public Set<Integer> getKeys() {
+    	return st.getKeys();
+    }
+    
+    @Override
+	public Iterator<Integer> iterator() {
+		// TODO Auto-generated method stub
+		return st.getKeys().iterator();
+	}
+    
+    public Iterable<Integer> getKeysIter() {
+    	return st.keys();
+    }
+    
+    // put st[i] = value
+    public void put(int i, double value) {
+        if (i < 0 || i >= N) throw new RuntimeException("Illegal index");
+        if (value == 0.0) st.delete(i);
+        else              st.put(i, value);
+    }
+
+    // return st[i]
+    public double get(int i) {
+        if (i < 0 || i >= N) throw new RuntimeException("Illegal index");
+        if (st.contains(i)) return st.get(i);
+        else                return 0.0;
+    }
+
+    // return the number of nonzero entries
+    public int nnz() {
+        return st.size();
+    }
+
+    // return the size of the vector
+    public int size() {
+        return N;
+    }
+
+    // return the dot product of this vector a with b
+    public double dot(SparseVector b) {
+        SparseVector a = this;
+        if (a.N != b.N) throw new RuntimeException("Vector lengths disagree");
+        double sum = 0.0;
+
+        // iterate over the vector with the fewest nonzeros
+        if (a.st.size() <= b.st.size()) {
+            for (int i : a.st.keys())
+                if (b.st.contains(i)) sum += a.get(i) * b.get(i);
+        }
+        else  {
+            for (int i : b.st.keys())
+                if (a.st.contains(i)) sum += a.get(i) * b.get(i);
+        }
+        return sum;
+    }
+
+    // return the 2-norm
+    public double norm() {
+        SparseVector a = this;
+        return Math.sqrt(a.dot(a));
+    }
+
+    // return alpha * a
+    public SparseVector scale(double alpha) {
+        SparseVector a = this;
+        SparseVector c = new SparseVector(N);
+        for (int i : a.st.keys()) c.put(i, alpha * a.get(i));
+        return c;
+    }
+
+    // return a + b
+    public SparseVector plus(SparseVector b) {
+        SparseVector a = this;
+        if (a.N != b.N) throw new RuntimeException("Vector lengths disagree");
+        SparseVector c = new SparseVector(N);
+        for (int i : a.st.keys()) c.put(i, a.get(i));                // c = a
+        for (int i : b.st.keys()) c.put(i, b.get(i) + c.get(i));     // c = c + b
+        return c;
+    }
+
+    // return a string representation
+    public String toString() {
+        StringBuilder s = new StringBuilder();
+        for (int i : st.keys()) {
+            s.append("(" + i + ", " + st.get(i) + ") ");
+        }
+        return s.toString();
+    }
+
+
+    // test client
+    public static void main(String[] args) {
+        SparseVector a = new SparseVector(10);
+        SparseVector b = new SparseVector(10);
+        a.put(3, 0.50);
+        a.put(9, 0.75);
+        a.put(6, 0.11);
+        a.put(6, 0.00);
+        b.put(3, 0.60);
+        b.put(4, 0.90);
+        System.out.println("a = " + a);
+        System.out.println("b = " + b);
+        System.out.println("a dot b = " + a.dot(b));
+        System.out.println("a + b   = " + a.plus(b));
+    }
+
+	
+
+}