Diff of /cmaes/solver.py [000000] .. [077a87]

Switch to unified view

a b/cmaes/solver.py
1
# Copyright (c) 2015, Disney Research
2
# All rights reserved.
3
#
4
# Author(s): Sehoon Ha <sehoon.ha@disneyresearch.com>
5
# Disney Research Robotics Group
6
7
8
"""
9
Example Usage:
10
11
# prob has f() and g() as member functions
12
13
solver = Solver(prob)
14
res = solver.solve()
15
x = res['x']
16
f = res['f']
17
"""
18
import cmaes.utils
19
import numpy as np
20
21
22
class Solver(object):
23
    def __init__(self, prob):
24
        self.prob = prob
25
        self.eval_counter = 0
26
        self.numerical_diff_step = 1e-4
27
        self.iter_values = list()
28
        self.verbose = True
29
        self.check_gradient = False
30
31
    def get_check_gradient(self, ):
32
        return self.check_gradient
33
34
    def set_check_gradient(self, check_gradient=True):
35
        self.check_gradient = check_gradient
36
37
    def eval_f(self, x):
38
        ret = self.prob.f(x)
39
40
        self.eval_counter += 1
41
        self.last_x = x
42
        self.last_f = ret
43
        self.iter_values.append(ret)
44
        if hasattr(self.prob, 'on_eval_f'):
45
            self.prob.on_eval_f(self)
46
47
        return ret
48
49
    def eval_g(self, x):
50
        if hasattr(self.prob, 'g'):
51
            ret = self.prob.g(x)
52
            if self.get_check_gradient():
53
                h = self.numerical_diff_step
54
                ret2 = utils.grad(self.prob.f, x, h)
55
                isgood = np.allclose(ret, ret2, atol=1e-05)
56
                if not isgood:
57
                    print(ret)
58
                    print(ret2)
59
                    print("diff = %.12f" % np.linalg.norm(ret - ret2))
60
                print('check_gradient... %s' % isgood)
61
        else:
62
            h = self.numerical_diff_step
63
            ret = utils.grad(self.prob.f, x, h)
64
        return ret
65
66
    def eval_c_eq_jac(self, x, i):
67
        if hasattr(self.prob, 'c_eq_jac'):
68
            ret = self.prob.c_eq_jac(x, i)
69
        else:
70
            h = self.numerical_diff_step
71
72
            def c_eq_f(x):
73
                return self.prob.c_eq(x, i)
74
            ret = utils.grad(c_eq_f, x, h)
75
        return ret
76
77
    def eval_c_ineq_jac(self, x, i):
78
        if hasattr(self.prob, 'c_ineq_jac'):
79
            ret1 = self.prob.c_ineq_jac(x, i)
80
            return ret1
81
        else:
82
            h = self.numerical_diff_step
83
84
            def c_ineq_f(x):
85
                return self.prob.c_ineq(x, i)
86
            ret2 = utils.grad(c_ineq_f, x, h)
87
            return ret2
88
89
    def collect_constraints(self):
90
        constraints = list()
91
        if hasattr(self.prob, 'num_eq_constraints'):
92
            num = self.prob.num_eq_constraints()
93
            if self.verbose:
94
                print('  [Solver]: num_eq_constraints = %d' % num)
95
            for i in range(num):
96
                c = dict()
97
                c['type'] = 'eq'
98
                assert(hasattr(self.prob, 'c_eq'))
99
                c['fun'] = self.prob.c_eq
100
                c['jac'] = self.eval_c_eq_jac
101
                c['args'] = [i]
102
                constraints.append(c)
103
        if hasattr(self.prob, 'num_ineq_constraints'):
104
            num = self.prob.num_ineq_constraints()
105
            if self.verbose:
106
                print('  [Solver]: num_ineq_constraints = %d' % num)
107
            for i in range(num):
108
                c = dict()
109
                c['type'] = 'ineq'
110
                assert(hasattr(self.prob, 'c_ineq'))
111
                c['fun'] = self.prob.c_ineq
112
                c['jac'] = self.eval_c_ineq_jac
113
                c['args'] = [i]
114
                constraints.append(c)
115
        if self.verbose:
116
            print('  [Solver]: num_constraints = %d' % len(constraints))
117
        return constraints
118
119
    def bounds(self):
120
        if hasattr(self.prob, 'bounds'):
121
            return self.prob.bounds()
122
        else:
123
            return None
124
125
    def solve(self, x0=None):
126
        pass
127
128
    def set_verbose(self, verbose):
129
        self.verbose = verbose
130
131
    def plot_convergence(self, filename=None):
132
        yy = self.iter_values
133
        xx = range(len(yy))
134
        import matplotlib.pyplot as plt
135
        # Plot
136
        plt.ioff()
137
        fig = plt.figure()
138
        fig.set_size_inches(18.5, 10.5)
139
        font = {'size': 28}
140
        plt.title('Value over # evaluations')
141
        plt.xlabel('X', fontdict=font)
142
        plt.ylabel('Y', fontdict=font)
143
        plt.plot(xx, yy)
144
        plt.axes().set_yscale('log')
145
        if filename is None:
146
            filename = 'plots/iter.png'
147
        plt.savefig(filename, bbox_inches='tight')
148
        plt.close(fig)
149
        print('plotting convergence OK.. ' + filename)
150
151
    def save_result(self, res, filename):
152
        with open(filename, 'w+') as fin:
153
            fin.write(str(res))
154
        print('writing result OK.. ' + filename)