Merge task conditional replays #23

Merged · 3 commits · Nov 5, 2023
55 changes: 51 additions & 4 deletions evaluate.py
@@ -9,11 +9,13 @@
from dataclasses import asdict
from itertools import cycle

import dill
import numpy as np
import torch
import pandas as pd

from nmmo.render.replay_helper import FileReplayHelper
from nmmo.task.task_spec import make_task_from_spec

import pufferlib
from pufferlib.vectorization import Serial, Multiprocessing
@@ -36,9 +38,10 @@ def setup_policy_store(policy_store_dir):
policy_store = DirectoryPolicyStore(policy_store_dir)
return policy_store

def save_replays(policy_store_dir, save_dir):
def save_replays(policy_store_dir, save_dir, curriculum_file, task_to_assign=None):
# load the checkpoints into the policy store
policy_store = setup_policy_store(policy_store_dir)
policy_ranker = create_policy_ranker(policy_store_dir)
num_policies = len(policy_store._all_policies())

# setup the replay path
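
Note: with the extended signature, save_replays can also be invoked directly from Python. A minimal sketch, with illustrative paths (it assumes evaluate.py keeps its argument parsing under a __main__ guard, so the import is side-effect free):

from evaluate import save_replays

save_replays(
    policy_store_dir="policy_store",    # directory holding the checkpoints
    save_dir="replays",                 # where the replay file is written
    curriculum_file="reinforcement_learning/eval_task_with_embedding.pkl",
    task_to_assign=0,  # optional; None keeps the tasks sampled from the curriculum file
)
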
@@ -56,6 +59,7 @@ def save_replays(policy_store_dir, save_dir):
args.selfplay_num_policies = num_policies + 1
args.early_stop_agent_num = 0 # run the full episode
args.resilient_population = 0 # no resilient agents
args.tasks_path = curriculum_file # task-conditioning

# NOTE: This creates a dummy learner agent. Is it necessary?
from reinforcement_learning import policy # import your policy
@@ -81,6 +85,7 @@ def make_policy(envs):
selfplay_learner_weight=args.learner_weight,
selfplay_num_policies=args.selfplay_num_policies,
policy_store=policy_store,
policy_ranker=policy_ranker, # so that a new ranker is created
data_dir=save_dir,
)

@@ -97,9 +102,22 @@ def make_policy(envs):
replay_helper = FileReplayHelper()
nmmo_env = evaluator.buffers[0].envs[0].envs[0].env
nmmo_env.realm.record_replay(replay_helper)
replay_helper.reset()

if task_to_assign is not None:
with open(curriculum_file, 'rb') as f:
task_with_embedding = dill.load(f) # a list of TaskSpec
assert 0 <= task_to_assign < len(task_with_embedding), "Task index out of range"
select_task = task_with_embedding[task_to_assign]

# Assign the task to the env
tasks = make_task_from_spec(nmmo_env.possible_agents,
[select_task] * len(nmmo_env.possible_agents))
nmmo_env.tasks = tasks # this is a hack
print("seed:", args.seed,
", task:", nmmo_env.tasks[0].spec_name)

# Run an episode to generate the replay
replay_helper.reset()
while True:
with torch.no_grad():
actions, logprob, value, _ = evaluator.policy_pool.forwards(
@@ -112,10 +130,28 @@ def make_policy(envs):
o, r, d, i = evaluator.buffers[0].recv()

num_alive = len(nmmo_env.realm.players)
print('Tick:', nmmo_env.realm.tick, ", alive agents:", num_alive)
task_done = sum(1 for task in nmmo_env.tasks if task.completed)
alive_done = sum(1 for task in nmmo_env.tasks
if task.completed and task.assignee[0] in nmmo_env.realm.players)
print("Tick:", nmmo_env.realm.tick, ", alive agents:", num_alive, ", task done:", task_done)
if num_alive == alive_done:
print("All alive agents completed the task.")
break
if num_alive == 0 or nmmo_env.realm.tick == args.max_episode_length:
print("All agents died or reached the max episode length.")
break

# Count how many agents completed the task
print("--------------------------------------------------")
print("Task:", nmmo_env.tasks[0].spec_name)
num_completed = sum(1 for task in nmmo_env.tasks if task.completed)
print("Number of agents completed the task:", num_completed)
avg_progress = np.mean([task.progress_info["max_progress"] for task in nmmo_env.tasks])
print(f"Average maximum progress (max=1): {avg_progress:.3f}")
avg_completed_tick = np.mean([task.progress_info["completed_tick"]
for task in nmmo_env.tasks if task.completed])
print(f"Average completed tick: {avg_completed_tick:.1f}")

# Save the replay file
replay_file = os.path.join(save_dir, f"replay_{time.strftime('%Y%m%d_%H%M%S')}")
logging.info("Saving replay to %s", replay_file)
@@ -243,6 +279,8 @@ def make_policy(envs):
-s, --replay-save-dir: Directory to save replays (Default: replays/)
-r, --replay-mode: Replay save mode (Default: False)
-d, --device: Device to use for evaluation/ranking (Default: cuda if available, otherwise cpu)
-t, --task-file: Task file to use for evaluation (Default: reinforcement_learning/eval_task_with_embedding.pkl)
-i, --task-index: The index of the task to assign in the curriculum file (Default: None)

To generate replays from your checkpoints, put them together in policy_store_dir and run the following command (an illustrative form is shown below); replays will be saved under replays/. The script uses only one environment.
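
The command itself is truncated in this view. An illustrative invocation built from the flags documented above (hypothetical paths; it assumes --replay-mode acts as a boolean switch, and the policy-store flag sits outside this hunk, so -p is assumed):

python evaluate.py \
    -p policy_store \
    -r \
    -t reinforcement_learning/eval_task_with_embedding.pkl \
    -i 0 \
    -s replays
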
@@ -297,14 +335,23 @@ def make_policy(envs):
default="reinforcement_learning/eval_task_with_embedding.pkl",
help="Task file to use for evaluation",
)
parser.add_argument(
"-i",
"--task-index",
dest="task_index",
type=int,
default=None,
help="The index of the task to assign in the curriculum file",
)

# Parse and check the arguments
eval_args = parser.parse_args()
assert eval_args.policy_store_dir is not None, "Policy store directory must be specified"

if getattr(eval_args, "replay_mode", False):
logging.info("Generating replays from the checkpoints in %s", eval_args.policy_store_dir)
save_replays(eval_args.policy_store_dir, eval_args.replay_save_dir)
save_replays(eval_args.policy_store_dir, eval_args.replay_save_dir,
eval_args.task_file, eval_args.task_index)
else:
logging.info("Ranking checkpoints from %s", eval_args.policy_store_dir)
logging.info("Replays will NOT be generated")
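To choose a value for --task-index, the curriculum file can be inspected offline. A sketch, assuming it holds the dill-pickled list of TaskSpec objects that save_replays expects:

import dill

with open("reinforcement_learning/eval_task_with_embedding.pkl", "rb") as f:
    task_specs = dill.load(f)  # a list of TaskSpec, per save_replays above

for idx, spec in enumerate(task_specs):
    # the exact TaskSpec attribute for its name may differ; fall back to repr
    print(idx, getattr(spec, "name", spec))
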
6 changes: 2 additions & 4 deletions reinforcement_learning/clean_pufferl.py
@@ -144,10 +144,8 @@ def __post_init__(self, *args, **kwargs):
# Create policy ranker
if self.policy_ranker is None:
if self.data_dir is not None:
self.policy_ranker = pufferlib.policy_ranker.OpenSkillRanker(
os.path.join(self.data_dir, "openskill.pickle"),
"anchor",
)
db_file = os.path.join(self.data_dir, "ranking.sqlite")
self.policy_ranker = pufferlib.policy_ranker.OpenSkillRanker(db_file, "anchor")
if "learner" not in self.policy_ranker.ratings():
self.policy_ranker.add_policy("learner")

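Note: the ranker now persists ratings in a SQLite file instead of a pickle. Standalone, the same construction from this hunk would read (data_dir illustrative):

import os
import pufferlib.policy_ranker

data_dir = "runs/eval"  # illustrative output directory
db_file = os.path.join(data_dir, "ranking.sqlite")
ranker = pufferlib.policy_ranker.OpenSkillRanker(db_file, "anchor")
if "learner" not in ranker.ratings():
    ranker.add_policy("learner")
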
2 changes: 0 additions & 2 deletions requirements.txt
@@ -5,12 +5,10 @@ openelm
pandas==2.0.3
plotly==5.15.0
psutil==5.9.3
ray==2.6.1
scikit-learn==1.3.0
tensorboard==2.11.2
tiktoken==0.4.0
torch==1.13.1
torchtyping==0.1.4
traitlets==5.9.0
transformers==4.31.0
wandb==0.13.7
4 changes: 0 additions & 4 deletions train.py
@@ -79,17 +79,13 @@ def curriculum_generation_track(trainer, args, use_elm=True):
if use_elm:
from curriculum_generation import manual_curriculum
from curriculum_generation.elm import OpenELMTaskGenerator
AGENT_MODEL_PATH = ""
NUM_SEED_TASKS = 20
NUM_NEW_TASKS = 5
ELM_DEBUG = True

task_encoder = TaskEncoder(LLM_CHECKPOINT, manual_curriculum, batch_size=2)
task_generator = OpenELMTaskGenerator(manual_curriculum.curriculum, LLM_CHECKPOINT)

# @daveey: We need a baseline checkpoint for this
#load_agent_model(AGENT_MODEL_PATH)

# Generating new tasks and evaluating all candidate training tasks
for _ in range(3):
# NOTE: adjust NUM_SEED_TASKS to fit your gpu