from collections import deque
import random
class ReplayBuffer(object):
def __init__(self, buffer_size):
self.buffer_size = buffer_size
self.num_experiences = 0
self.buffer = deque()
def get_batch(self, batch_size):
# Randomly sample batch_size examples
return random.sample(self.buffer, batch_size)
def size(self):
return self.buffer_size
def add(self,transition):
if self.num_experiences < self.buffer_size:
self.buffer.append(transition)
self.num_experiences += 1
else:
self.buffer.popleft()
self.buffer.append(transition)
def count(self):
# if buffer is full, return buffer size
# otherwise, return experience counter
return self.num_experiences
def erase(self):
self.buffer = deque()
self.num_experiences = 0