|
a |
|
b/baselines/common/cg.py |
|
|
1 |
import numpy as np |
|
|
2 |
def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
    """
    Approximately solve ``A x = b`` with the conjugate gradient method.

    Follows the formulation in Demmel p 312. The matrix ``A`` is never formed;
    it is only accessed through ``f_Ax``. (The method itself assumes ``A`` is
    symmetric positive definite; this function does not check that.)

    Parameters
    ----------
    f_Ax : callable
        Maps a vector ``p`` to the matrix-vector product ``A p``.
    b : numpy.ndarray
        Right-hand side. It is copied, never modified. Presumably a 1-D
        floating-point array, since ``dot`` results are used as scalars and
        the updates are done in place -- confirm against callers.
    cg_iters : int
        Maximum number of iterations.
    callback : callable or None
        If given, called with the current iterate at the start of every
        iteration and once more before returning. The same array object is
        passed each time and is updated in place between calls.
    verbose : bool
        If true, print a progress table to stdout.
    residual_tol : float
        Iteration stops once the squared residual norm ``r.r`` drops below
        this value.

    Returns
    -------
    numpy.ndarray
        The final iterate ``x`` (starting from ``x = 0``).
    """
    p = b.copy()  # search direction
    r = b.copy()  # residual b - A x; equals b because x starts at zero
    x = np.zeros_like(b)
    rdotr = r.dot(r)

    fmtstr = "%10i %10.3g %10.3g"
    titlestr = "%10s %10s %10s"
    if verbose:
        print(titlestr % ("iter", "residual norm", "soln norm"))

    # Number of completed update steps. Tracked explicitly so the final report
    # does not depend on the loop variable, which is unbound if cg_iters == 0.
    steps_done = 0
    for i in range(cg_iters):
        if callback is not None:
            callback(x)
        if verbose:
            print(fmtstr % (i, rdotr, np.linalg.norm(x)))
        z = f_Ax(p)
        pdotz = p.dot(z)
        if pdotz == 0:
            # Zero curvature along p (e.g. b is all zeros, so p is zero):
            # the step length would be 0/0 and fill x with NaN. No further
            # progress is possible, so keep the current iterate.
            break
        v = rdotr / pdotz
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        mu = newrdotr / rdotr
        p = r + mu * p

        rdotr = newrdotr
        steps_done = i + 1
        if rdotr < residual_tol:
            break

    if callback is not None:
        callback(x)
    if verbose:
        print(fmtstr % (steps_done, rdotr, np.linalg.norm(x)))
    return x