--- a +++ b/ddpg/replay_buffer.py @@ -0,0 +1,34 @@ +from collections import deque +import random + +class ReplayBuffer(object): + + def __init__(self, buffer_size): + self.buffer_size = buffer_size + self.num_experiences = 0 + self.buffer = deque() + + def get_batch(self, batch_size): + # Randomly sample batch_size examples + return random.sample(self.buffer, batch_size) + + def size(self): + return self.buffer_size + + def add(self, state, action, reward, new_state, done): + experience = (state, action, reward, new_state, done) + if self.num_experiences < self.buffer_size: + self.buffer.append(experience) + self.num_experiences += 1 + else: + self.buffer.popleft() + self.buffer.append(experience) + + def count(self): + # if buffer is full, return buffer size + # otherwise, return experience counter + return self.num_experiences + + def erase(self): + self.buffer = deque() + self.num_experiences = 0 \ No newline at end of file