a b/ADDPG/replay_buffer.py
1
from collections import deque
2
import random
3
4
class ReplayBuffer(object):
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Once ``buffer_size`` transitions are stored, each new transition evicts
    the oldest one.
    """

    def __init__(self, buffer_size):
        """Create an empty buffer holding at most ``buffer_size`` transitions."""
        self.buffer_size = buffer_size
        # Number of transitions currently stored; never exceeds buffer_size.
        self.num_experiences = 0
        # Oldest transition on the left, newest on the right.
        self.buffer = deque()

    def get_batch(self, batch_size):
        """Return a list of up to ``batch_size`` distinct stored transitions.

        Sampling is uniform and without replacement. If fewer than
        ``batch_size`` transitions are stored, every stored transition is
        returned (in random order) instead of raising ``ValueError``.
        """
        # random.sample raises when k exceeds the population size; clamp so
        # sampling keeps working while the buffer is still filling up.
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def size(self):
        """Return the buffer's capacity (``buffer_size``), not its occupancy.

        Use ``count()`` for the number of transitions actually stored.
        """
        return self.buffer_size

    def add(self, transition):
        """Store ``transition``, evicting the oldest one if the buffer is full."""
        if self.num_experiences < self.buffer_size:
            self.buffer.append(transition)
            self.num_experiences += 1
        elif self.buffer:
            # Full: drop the oldest entry to make room for the new one. The
            # emptiness check keeps a zero-capacity buffer from calling
            # popleft() on an empty deque (IndexError); it just stores nothing.
            self.buffer.popleft()
            self.buffer.append(transition)

    def count(self):
        """Return the number of transitions currently stored.

        The value is capped at ``buffer_size`` once the buffer is full.
        """
        return self.num_experiences

    def erase(self):
        """Discard all stored transitions."""
        self.buffer = deque()
        self.num_experiences = 0