From 087e73da477c5dac3ffa3d01fff758b607286445 Mon Sep 17 00:00:00 2001
From: maxreciprocate <56548574+maxreciprocate@users.noreply.github.com>
Date: Fri, 23 Jun 2023 13:56:00 +0300
Subject: [PATCH] fix(base_trainer): force pad_token regardless of architecture

---
 trlx/trainer/accelerate_base_trainer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 6d355cec7..314eaa80c 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -73,9 +73,8 @@ def __init__(self, config, **kwargs): # noqa: C901
         self.tokenizer.padding_side = config.tokenizer.padding_side
         self.tokenizer.truncation_side = config.tokenizer.truncation_side
         self.tokenizer.sep_token = ""
-        if config.model.model_arch_type != "seq2seq":
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = "<|padding|>"
 
         script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0]
         if not isinstance(config.model.model_path, str):