a b/network/SparseVector.java
1
/******************************************************************************
2
 *  Compilation:  javac SparseVector.java
3
 *  Execution:    java SparseVector
4
 *  
5
 *  A sparse vector, implemented using a symbol table.
6
 *
7
 *  [Not clear we need the instance variable N except for error checking.]
8
 *
9
 ******************************************************************************/
10
package network;
11
12
import java.util.HashMap;
13
import java.util.HashSet;
14
import java.util.Hashtable;
15
import java.util.Iterator;
16
import java.util.Set;
17
18
public class SparseVector implements Iterable<Integer> {
19
    private final int N;             // length
20
    private ST<Integer, Double> st;  // the vector, represented by index-value pairs
21
22
    // initialize the all 0s vector of length N
23
    
24
    
25
    public SparseVector(int N) {
26
        this.N  = N;
27
        this.st = new ST<Integer, Double>();
28
    }
29
30
    public Set<Integer> getKeys() {
31
        return st.getKeys();
32
    }
33
    
34
    @Override
35
    public Iterator<Integer> iterator() {
36
        // TODO Auto-generated method stub
37
        return st.getKeys().iterator();
38
    }
39
    
40
    public Iterable<Integer> getKeysIter() {
41
        return st.keys();
42
    }
43
    
44
    // put st[i] = value
45
    public void put(int i, double value) {
46
        if (i < 0 || i >= N) throw new RuntimeException("Illegal index");
47
        if (value == 0.0) st.delete(i);
48
        else              st.put(i, value);
49
    }
50
51
    // return st[i]
52
    public double get(int i) {
53
        if (i < 0 || i >= N) throw new RuntimeException("Illegal index");
54
        if (st.contains(i)) return st.get(i);
55
        else                return 0.0;
56
    }
57
58
    // return the number of nonzero entries
59
    public int nnz() {
60
        return st.size();
61
    }
62
63
    // return the size of the vector
64
    public int size() {
65
        return N;
66
    }
67
68
    // return the dot product of this vector a with b
69
    public double dot(SparseVector b) {
70
        SparseVector a = this;
71
        if (a.N != b.N) throw new RuntimeException("Vector lengths disagree");
72
        double sum = 0.0;
73
74
        // iterate over the vector with the fewest nonzeros
75
        if (a.st.size() <= b.st.size()) {
76
            for (int i : a.st.keys())
77
                if (b.st.contains(i)) sum += a.get(i) * b.get(i);
78
        }
79
        else  {
80
            for (int i : b.st.keys())
81
                if (a.st.contains(i)) sum += a.get(i) * b.get(i);
82
        }
83
        return sum;
84
    }
85
86
    // return the 2-norm
87
    public double norm() {
88
        SparseVector a = this;
89
        return Math.sqrt(a.dot(a));
90
    }
91
92
    // return alpha * a
93
    public SparseVector scale(double alpha) {
94
        SparseVector a = this;
95
        SparseVector c = new SparseVector(N);
96
        for (int i : a.st.keys()) c.put(i, alpha * a.get(i));
97
        return c;
98
    }
99
100
    // return a + b
101
    public SparseVector plus(SparseVector b) {
102
        SparseVector a = this;
103
        if (a.N != b.N) throw new RuntimeException("Vector lengths disagree");
104
        SparseVector c = new SparseVector(N);
105
        for (int i : a.st.keys()) c.put(i, a.get(i));                // c = a
106
        for (int i : b.st.keys()) c.put(i, b.get(i) + c.get(i));     // c = c + b
107
        return c;
108
    }
109
110
    // return a string representation
111
    public String toString() {
112
        StringBuilder s = new StringBuilder();
113
        for (int i : st.keys()) {
114
            s.append("(" + i + ", " + st.get(i) + ") ");
115
        }
116
        return s.toString();
117
    }
118
119
120
    // test client
121
    public static void main(String[] args) {
122
        SparseVector a = new SparseVector(10);
123
        SparseVector b = new SparseVector(10);
124
        a.put(3, 0.50);
125
        a.put(9, 0.75);
126
        a.put(6, 0.11);
127
        a.put(6, 0.00);
128
        b.put(3, 0.60);
129
        b.put(4, 0.90);
130
        System.out.println("a = " + a);
131
        System.out.println("b = " + b);
132
        System.out.println("a dot b = " + a.dot(b));
133
        System.out.println("a + b   = " + a.plus(b));
134
    }
135
136
    
137
138
}