diff --git a/trlx/pipeline/ppo_pipeline.py b/trlx/pipeline/ppo_pipeline.py index 19bae03af..2b1ea374e 100644 --- a/trlx/pipeline/ppo_pipeline.py +++ b/trlx/pipeline/ppo_pipeline.py @@ -84,7 +84,7 @@ def filter_text(d, only_text): d.pop(key) return d - data = [filter_text(exp_to_dict(exp), only_text) for exp in self.history] + data = [filter_text(exp_to_dict(exp), only_text) for exp in self.history] with open(fpath, "w") as f: f.write(json.dumps(data, indent=2))