|
a |
|
b/rdpg/replay_buffer.py |
|
|
1 |
from collections import deque |
|
|
2 |
import random |
|
|
3 |
|
|
|
4 |
class ReplayBuffer(object):
    """Bounded first-in/first-out store of experiences with random sampling.

    Items are kept in insertion order in a ``deque``. Once ``buffer_size``
    items are stored, each new item evicts the oldest one.

    Attributes:
        buffer_size: Maximum number of items kept at once.
        num_experiences: Number of items currently stored.
        buffer: The underlying ``deque`` holding the items, oldest first.
    """

    def __init__(self, buffer_size):
        """Create an empty buffer that keeps at most ``buffer_size`` items."""
        self.buffer_size = buffer_size
        self.num_experiences = 0
        self.buffer = deque()

    def get_batch(self, batch_size):
        """Return a list of distinct stored items picked at random.

        Args:
            batch_size: Requested number of items.

        Returns:
            A list of ``batch_size`` items, or of every stored item (in
            random order) when fewer than ``batch_size`` are available.

        Raises:
            ValueError: If ``batch_size`` is negative.
        """
        # random.sample raises ValueError when asked for more items than the
        # population holds, which is the normal situation while the buffer
        # is still filling up; clamp the request to what is stored.
        sample_size = min(batch_size, len(self.buffer))
        return random.sample(self.buffer, sample_size)

    def size(self):
        """Return the capacity of the buffer (not the number stored).

        Use ``count()`` for the number of items currently held.
        """
        return self.buffer_size

    def add(self, history):
        """Append ``history``, evicting the oldest item when at capacity."""
        if self.num_experiences < self.buffer_size:
            self.buffer.append(history)
            self.num_experiences += 1
        else:
            # At capacity: drop the oldest entry so the count stays constant.
            self.buffer.popleft()
            self.buffer.append(history)

    def count(self):
        """Return the number of items currently stored.

        ``add`` stops incrementing the counter once the buffer is full, so
        this value never exceeds ``buffer_size``.
        """
        return self.num_experiences

    def erase(self):
        """Discard every stored item and reset the counter to zero."""
        self.buffer = deque()
        self.num_experiences = 0