Diff of /ddpg/replay_buffer.py [000000] .. [687a25]

Switch to unified view

a b/ddpg/replay_buffer.py
1
from collections import deque
2
import random
3
4
class ReplayBuffer(object):
5
6
    def __init__(self, buffer_size):
7
        self.buffer_size = buffer_size
8
        self.num_experiences = 0
9
        self.buffer = deque()
10
11
    def get_batch(self, batch_size):
12
        # Randomly sample batch_size examples
13
        return random.sample(self.buffer, batch_size)
14
15
    def size(self):
16
        return self.buffer_size
17
18
    def add(self, state, action, reward, new_state, done):
19
        experience = (state, action, reward, new_state, done)
20
        if self.num_experiences < self.buffer_size:
21
            self.buffer.append(experience)
22
            self.num_experiences += 1
23
        else:
24
            self.buffer.popleft()
25
            self.buffer.append(experience)
26
27
    def count(self):
28
        # if buffer is full, return buffer size
29
        # otherwise, return experience counter
30
        return self.num_experiences
31
32
    def erase(self):
33
        self.buffer = deque()
34
        self.num_experiences = 0