-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from xrsrke/dev
add Prioritized Level Replay (PLR) paper
- Loading branch information
Showing
4 changed files
with
168 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
class BaseLevel:
    """Base class for a training level.

    A level is identified by a string ``id``; both ``str()`` and
    ``repr()`` render that identifier.
    """

    def __init__(self, id: str) -> None:
        # NOTE: parameter name `id` shadows the builtin, but is kept
        # unchanged for interface compatibility with existing callers.
        self.id = id

    def __str__(self) -> str:
        return f"{self.id}"

    def __repr__(self) -> str:
        return str(self)
|
||
class Level(BaseLevel):
    """A concrete training level; inherits all behaviour from BaseLevel."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from typing import Dict, Union, List | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch.distributions import Bernoulli | ||
from torchtyping import TensorType | ||
|
||
from .level import Level | ||
|
||
|
||
class PrioritizedReplayDistribution:
    """Prioritized replay distribution over seen training levels (PLR).

    Mixes a score-based distribution (levels with higher learning
    potential are replayed more often) with a staleness-based
    distribution (levels not played recently are replayed more often).
    """
    def __init__(
        self,
        staleness_coeff: float = 0.1,
        temperature: float = 0.1,  # the beta coefficient for the score-prioritized distribution
    ) -> None:
        """
        Args:
            staleness_coeff: mixture weight given to the staleness distribution.
            temperature: beta coefficient; scores are raised to the power
                1/temperature before normalization.
        """
        self.staleness_coeff = staleness_coeff
        self.temperature = temperature

    def create(
        self,
        score_levels: Dict["Level", Union[int, float]],
        last_episode_levels: Dict["Level", int],  # the last episode that each level was played
        last_episode: int  # the last episode
    ) -> torch.Tensor:
        """Create a prioritized level distribution.

        Args:
            score_levels: learning-potential score of each visited level.
            last_episode_levels: the last episode each level was played.
            last_episode: index of the most recent episode.

        Returns:
            1-D float tensor of probabilities (sums to 1), one entry per
            level, in the iteration order of ``score_levels``.
        """
        # Cast to float explicitly: F.normalize raises on integer tensors,
        # and integer staleness values would make the divisions below fragile.
        scores = torch.tensor(list(score_levels.values()), dtype=torch.float)
        last_played = torch.tensor(
            list(last_episode_levels.values()), dtype=torch.float
        )

        level_scores = torch.pow(
            input=F.normalize(scores, dim=-1),
            exponent=1 / self.temperature
        )
        score_dist = level_scores / level_scores.sum(dim=-1)

        stale_scores = last_episode - last_played
        stale_sum = stale_scores.sum(dim=-1)
        if stale_sum > 0:
            stale_dist = stale_scores / stale_sum
        else:
            # Robustness fix: if every level was played in the most recent
            # episode, total staleness is zero; fall back to a uniform
            # distribution instead of dividing by zero (which yields NaN).
            stale_dist = torch.full_like(stale_scores, 1.0 / len(stale_scores))

        prioritized_dist = (
            (1 - self.staleness_coeff) * score_dist
            + self.staleness_coeff * stale_dist
        )

        return prioritized_dist
|
||
|
||
class PrioritizedReplay:
    """Samples training levels, trading off unseen levels against replay (PLR)."""
    def __init__(
        self,
        levels: List["Level"],
    ) -> None:
        """
        Args:
            levels: the full set of training levels to sample from.
        """
        self.levels = levels
        # number of times each level has been visited so far
        self.visited_count_levels: Dict["Level", int] = {}

        self.prioritized_dist = PrioritizedReplayDistribution()

    def sample_next_level(
        self,
        visited_levels: List["Level"],
        score_levels: Dict[str, Union[int, float]],
        last_episode_levels: Dict[str, int],
        last_episode: int
    ) -> "Level":
        """Sampling a level from the replay distribution.

        With probability 0.5 an unseen level is drawn uniformly; otherwise
        a visited level is drawn from the prioritized replay distribution.

        Args:
            visited_levels: levels already played, in the same order as the
                entries of ``score_levels`` / ``last_episode_levels``.
            score_levels: learning-potential score per visited level.
            last_episode_levels: the last episode each level was played.
            last_episode: index of the most recent episode.

        Returns:
            The next level to train on.
        """
        decision_dist = Bernoulli(probs=0.5)
        unseen_levels = [level for level in self.levels if level not in visited_levels]

        # Robustness fix: never take the replay branch when nothing has been
        # visited yet (torch.multinomial over an empty distribution fails);
        # symmetrically, replay when no unseen level remains.
        take_unseen = len(unseen_levels) > 0 and (
            decision_dist.sample() == 0 or len(visited_levels) == 0
        )

        if take_unseen:
            # sample an unseen level uniformly at random
            uniform_dist = torch.rand(len(unseen_levels))
            selected_index = torch.argmax(uniform_dist)
            next_level = unseen_levels[selected_index]

            self.visited_count_levels[next_level] = 1
        else:
            # sample a level for replay
            prioritized_dist = self.prioritized_dist.create(
                score_levels,
                last_episode_levels,
                last_episode
            )
            visited_idx = torch.multinomial(prioritized_dist, num_samples=1)
            # .item() makes the tensor->int conversion explicit for list indexing
            next_level = visited_levels[visited_idx.item()]

        return next_level
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from prioritized_level_replay.level import BaseLevel, Level | ||
|
||
def test_base_level():
    """BaseLevel stores its id and renders it via str() and repr()."""
    # renamed from `id` so the local no longer shadows the builtin
    level_id = "level_1"

    level = BaseLevel(level_id)

    assert level.id == level_id
    assert str(level) == level_id
    assert repr(level) == level_id
|
||
|
||
def test_level():
    """Level behaves exactly like BaseLevel for id storage and rendering."""
    # renamed from `id` so the local no longer shadows the builtin
    level_id = "level_1"

    level = Level(level_id)

    assert level.id == level_id
    assert str(level) == level_id
    assert repr(level) == level_id
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import torch | ||
|
||
from prioritized_level_replay.level import Level | ||
from prioritized_level_replay.replay import ( | ||
PrioritizedReplayDistribution, | ||
PrioritizedReplay | ||
) | ||
|
||
def test_prioritized_replay_distribution():
    """The created distribution is a per-level probability tensor summing to 1."""
    level_1 = Level("level_1")
    level_2 = Level("level_2")
    level_3 = Level("level_3")

    score_levels = {level_1: 0.2, level_2: 0.1, level_3: 0.9}
    last_episode_levels = {level_1: 3, level_2: 1, level_3: 4}
    last_episode = 4
    dist = PrioritizedReplayDistribution()

    prioritized_dist = dist.create(score_levels, last_episode_levels, last_episode)

    assert isinstance(prioritized_dist, torch.Tensor)
    assert len(prioritized_dist) == len(score_levels)
    # compare floats with a tolerance instead of exact equality, which is
    # fragile for a sum of floating-point probabilities
    assert torch.isclose(prioritized_dist.sum(dim=-1), torch.tensor(1.0))
|
||
def test_prioritized_replay_2():
    """A sampled next level always comes from the configured level set."""
    all_levels = [Level(f"level_{i}") for i in range(1, 6)]
    level_1, _, level_3, _, _ = all_levels

    visited_levels = [level_1, level_3]
    score_levels = {level_1: 0.2, level_3: 0.1}
    last_episode_levels = {level_1: 3, level_3: 1}
    last_episode = 3

    replay = PrioritizedReplay(all_levels)

    assert replay.levels == all_levels

    next_level = replay.sample_next_level(
        visited_levels,
        score_levels,
        last_episode_levels,
        last_episode
    )

    assert isinstance(next_level, Level)
    assert next_level in all_levels