Merge pull request #6 from xrsrke/dev
add Prioritized Level Replay (PLR) paper
daveey committed Mar 27, 2023
2 parents 146ff8c + 4807e30 commit d07a594
Showing 4 changed files with 168 additions and 0 deletions.
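As context for the new module (not part of the commit itself): prioritized_level_replay/replay.py mixes a score-prioritized distribution with a staleness distribution, roughly following the replay distribution from the PLR paper. A sketch in the paper's notation, with \rho = staleness_coeff, \beta = temperature, S_i the score of level l_i, C_i the last episode in which l_i was played, and c the current episode:

P_{replay}(l_i) = (1 - \rho) \cdot P_S(l_i) + \rho \cdot P_C(l_i)
P_S(l_i) = \frac{h(S_i)^{1/\beta}}{\sum_j h(S_j)^{1/\beta}}
P_C(l_i) = \frac{c - C_i}{\sum_j (c - C_j)}

The code in this commit uses F.normalize on the raw scores as the prioritization h(.), whereas the paper's default is a rank-based transform, so the exact weighting differs.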
14 changes: 14 additions & 0 deletions prioritized_level_replay/level.py
@@ -0,0 +1,14 @@
class BaseLevel:
    """Base class for a training level."""
    def __init__(self, id: str) -> None:
        self.id = id

    def __str__(self) -> str:
        return str(self.id)

    def __repr__(self) -> str:
        return self.__str__()


class Level(BaseLevel):
    """A training level."""
    pass
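Note (not part of the commit): Level defines no __eq__ or __hash__, so levels compare and hash by object identity; the replay code that follows therefore assumes the same Level instances are reused as dictionary keys and list members throughout. A hedged sketch of an id-based alternative, should value equality ever be wanted:

class HashableLevel(BaseLevel):
    """Hypothetical variant: levels with the same id compare equal and hash alike."""
    def __eq__(self, other: object) -> bool:
        return isinstance(other, BaseLevel) and self.id == other.id

    def __hash__(self) -> int:
        return hash(self.id)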
84 changes: 84 additions & 0 deletions prioritized_level_replay/replay.py
@@ -0,0 +1,84 @@
from typing import Dict, Union, List

import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli
from torchtyping import TensorType

from .level import Level


class PrioritizedReplayDistribution:
    """Prioritized replay distribution over seen training levels."""
    def __init__(
        self,
        staleness_coeff: float = 0.1,
        temperature: float = 0.1,  # the beta coefficient for the score-prioritized distribution
    ) -> None:
        self.staleness_coeff = staleness_coeff
        self.temperature = temperature

    def create(
        self,
        score_levels: Dict[Level, Union[int, float]],
        last_episode_levels: Dict[Level, int],  # the last episode in which each level was played
        last_episode: int  # the index of the current episode
    ) -> TensorType["n_visited_levels"]:
        """Create a prioritized level distribution."""
        scores = torch.tensor([v for v in score_levels.values()], dtype=torch.float)
        last_episodes = torch.tensor([v for v in last_episode_levels.values()], dtype=torch.float)

        # score-prioritized distribution: normalize the scores and sharpen them with 1/temperature
        level_scores = torch.pow(
            input=F.normalize(scores, dim=-1),
            exponent=1 / self.temperature
        )
        score_dist = level_scores / level_scores.sum(dim=-1)

        # staleness distribution: levels that have not been played recently get more mass
        stale_scores = last_episode - last_episodes
        stale_dist = stale_scores / stale_scores.sum(dim=-1)

        # mix the two distributions using the staleness coefficient
        prioritized_dist = (1 - self.staleness_coeff) * score_dist + self.staleness_coeff * stale_dist

        return prioritized_dist


class PrioritizedReplay:
    def __init__(
        self,
        levels: List[Level],
    ) -> None:
        self.levels = levels
        self.visited_count_levels: Dict[Level, int] = {}

        self.prioritized_dist = PrioritizedReplayDistribution()

    def sample_next_level(
        self,
        visited_levels: List[Level],
        score_levels: Dict[Level, Union[int, float]],
        last_episode_levels: Dict[Level, int],
        last_episode: int
    ) -> Level:
        """Sample the next level: either an unseen level or a level replayed from the prioritized distribution."""
        decision_dist = Bernoulli(probs=0.5)
        unseen_levels = [level for level in self.levels if level not in visited_levels]

        if len(unseen_levels) > 0 and (len(visited_levels) == 0 or decision_dist.sample() == 0):
            # sample an unseen level uniformly at random
            uniform_dist = torch.rand(len(unseen_levels))
            selected_index = torch.argmax(uniform_dist).item()
            next_level = unseen_levels[selected_index]

            self.visited_count_levels[next_level] = 1
        else:
            # sample a seen level for replay from the prioritized distribution
            prioritized_dist = self.prioritized_dist.create(
                score_levels,
                last_episode_levels,
                last_episode
            )
            visited_idx = torch.multinomial(prioritized_dist, num_samples=1).item()
            next_level = visited_levels[visited_idx]

        return next_level
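A minimal usage sketch (not part of this commit) of how the sampler could be driven from a training loop; run_episode and evaluate_level_score are hypothetical placeholders for the environment rollout and the score estimate (e.g. the L1 value loss used in the PLR paper):

from prioritized_level_replay.level import Level
from prioritized_level_replay.replay import PrioritizedReplay

levels = [Level(f"level_{i}") for i in range(10)]
replay = PrioritizedReplay(levels)

visited_levels = []
score_levels = {}          # Level -> learning-potential score
last_episode_levels = {}   # Level -> last episode in which the level was played

for episode in range(100):
    level = replay.sample_next_level(visited_levels, score_levels, last_episode_levels, episode)

    trajectory = run_episode(level)             # hypothetical rollout
    score = evaluate_level_score(trajectory)    # hypothetical score estimate

    if level not in visited_levels:
        visited_levels.append(level)
    score_levels[level] = score
    last_episode_levels[level] = episode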
20 changes: 20 additions & 0 deletions tests/test_level.py
@@ -0,0 +1,20 @@
from prioritized_level_replay.level import BaseLevel, Level


def test_base_level():
    id = "level_1"

    level = BaseLevel(id)

    assert level.id == id
    assert str(level) == id
    assert repr(level) == id


def test_level():
    id = "level_1"

    level = Level(id)

    assert level.id == id
    assert str(level) == id
    assert repr(level) == id
50 changes: 50 additions & 0 deletions tests/test_replay.py
@@ -0,0 +1,50 @@
import torch

from prioritized_level_replay.level import Level
from prioritized_level_replay.replay import (
    PrioritizedReplayDistribution,
    PrioritizedReplay
)


def test_prioritized_replay_distribution():
    level_1 = Level("level_1")
    level_2 = Level("level_2")
    level_3 = Level("level_3")

    score_levels = {level_1: 0.2, level_2: 0.1, level_3: 0.9}
    last_episode_levels = {level_1: 3, level_2: 1, level_3: 4}
    last_episode = 4
    dist = PrioritizedReplayDistribution()

    prioritized_dist = dist.create(score_levels, last_episode_levels, last_episode)

    assert isinstance(prioritized_dist, torch.Tensor)
    assert len(prioritized_dist) == len(score_levels)
    # compare with a tolerance: the mixed distribution should sum to 1 up to floating-point error
    assert torch.isclose(prioritized_dist.sum(dim=-1), torch.tensor(1.0))


def test_prioritized_replay_2():
    level_1 = Level("level_1")
    level_2 = Level("level_2")
    level_3 = Level("level_3")
    level_4 = Level("level_4")
    level_5 = Level("level_5")

    levels = [level_1, level_2, level_3, level_4, level_5]
    visited_levels = [level_1, level_3]
    score_levels = {level_1: 0.2, level_3: 0.1}
    last_episode_levels = {level_1: 3, level_3: 1}
    last_episode = 3

    replay = PrioritizedReplay(levels)

    assert replay.levels == levels

    next_level = replay.sample_next_level(
        visited_levels,
        score_levels,
        last_episode_levels,
        last_episode
    )

    assert isinstance(next_level, Level)
    assert next_level in levels