a b/baselines/common/cg.py
1
import numpy as np
2
def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
    """
    Approximately solve ``A x = b`` with the conjugate gradient method.

    Reference: Demmel p 312.

    ``A`` is never formed explicitly; it is only accessed through ``f_Ax``.
    Conjugate gradient assumes ``A`` is symmetric positive-definite; this
    function does not check that.

    Parameters
    ----------
    f_Ax : callable
        Maps a vector ``p`` (same shape as ``b``) to the product ``A p``.
    b : array_like
        Right-hand side. Non-floating input (ints, bools, lists of ints) is
        promoted to float64 so the in-place updates below cannot fail on
        casting; floating input keeps its dtype.
    cg_iters : int
        Maximum number of CG iterations.
    callback : callable or None
        If given, called with the current iterate ``x`` at the start of every
        iteration and once more after the loop finishes.
    verbose : bool
        If true, print a per-iteration table. Note that the "residual norm"
        column shows ``r.dot(r)``, i.e. the *squared* residual norm.
    residual_tol : float
        Stop as soon as ``r.dot(r)`` drops below this value.

    Returns
    -------
    numpy.ndarray
        The approximate solution ``x``, with the same shape as ``b``.
    """
    b = np.asarray(b)
    if not np.issubdtype(b.dtype, np.inexact):
        # Integer x/r would reject the in-place float updates below.
        b = b.astype(np.float64)

    p = b.copy()
    r = b.copy()  # residual b - A x; x starts at zero so r starts at b
    x = np.zeros_like(b)
    rdotr = r.dot(r)
    n_iters = 0  # completed updates; reported on the final verbose line

    fmtstr = "%10i %10.3g %10.3g"
    titlestr = "%10s %10s %10s"
    if verbose:
        print(titlestr % ("iter", "residual norm", "soln norm"))

    for i in range(cg_iters):
        if callback is not None:
            callback(x)
        if verbose:
            print(fmtstr % (i, rdotr, np.linalg.norm(x)))
        z = f_Ax(p)
        pAp = p.dot(z)
        if pAp == 0:
            # Breakdown (e.g. b == 0, or an exactly-zero residual when
            # residual_tol <= 0): the step would be 0/0 and fill x with NaN.
            break
        v = rdotr / pAp
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        mu = newrdotr / rdotr
        p = r + mu * p
        rdotr = newrdotr
        n_iters += 1
        if rdotr < residual_tol:
            break

    if callback is not None:
        callback(x)
    if verbose:
        # n_iters is always bound, unlike the loop variable when cg_iters == 0.
        print(fmtstr % (n_iters, rdotr, np.linalg.norm(x)))
    return x